├── .gitignore ├── LICENSE ├── README.md ├── generate.py ├── loader.py ├── model.py ├── preprocess.py ├── train.py └── trimmer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | checkpoint/* 3 | generated/* 4 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "{}"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright {yyyy} {name of copyright owner}
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tacotron_pytorch
2 | PyTorch implementation of Tacotron: A Fully End-to-End Text-To-Speech Synthesis Model
3 | 
4 | https://arxiv.org/abs/1703.10135
5 | 
6 | ## Requirements
7 | * pytorch
8 | * librosa
9 | * py-webrtcvad
10 | 
11 | ## Data
12 | Please register to use the Blizzard Challenge data set. (http://www.cstr.ed.ac.uk/projects/blizzard/)
13 | 
14 | In the code, the option 'blizzard' is for the Blizzard Challenge data of 2013.
15 | 
16 | The option 'etri' is for a Korean TTS dataset published by ETRI, and you need to buy a license to use it.
17 | 
18 | You need to download and unzip the data from the website.
19 | 
20 | Then, set the paths in the code (train.py, preprocess.py, generate.py) accordingly (find 'dir_' and change the lines that follow).
21 | 
22 | ## How to run
23 | * Please refer to the code to see what options/hyperparameters are available.
24 | 1. Prepare the data and preprocess it (e.g. blizzard) by running: preprocess.py --data 'blizzard'
25 | (You may want to trim silences in the audio files before preprocessing; use trimmer.py for that.)
26 | 2. Run 'train.py' with arguments.
27 | 3. After training, run 'generate.py' with arguments to get a generated audio file.
28 | 
29 | 
30 | ## Comment
31 | Contributions and comments are always welcome.
32 | 
33 | I referred to https://github.com/keithito/tacotron for the preprocessing code. Thank you.
34 | 
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import unicode_literals, print_function, division
3 | import argparse, librosa, copy, shutil, pdb, multiprocessing, re
4 | import numpy as np
5 | 
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | 
10 | from model import Tacotron as Tacotron
11 | from loader import DataLoader
12 | 
13 | 
14 | def main():
15 |     parser = argparse.ArgumentParser(description='generation script')
16 |     # data load
17 |     parser.add_argument('--data', type=str, default='blizzard', help='blizzard / nancy')
18 |     parser.add_argument('--batch_size', type=int, default=6, help='batch size')
19 |     parser.add_argument('--text_limit', type=int, default=1500, help='maximum length of text to include in training set')
20 |     parser.add_argument('--wave_limit', type=int, default=800, help='maximum length of spectrogram to include in training set')
21 |     parser.add_argument('--shuffle_data', type=int, default=0, help='whether to shuffle data loader')
22 |     parser.add_argument('--batch_idx', type=int, default=0, help='n-th batch of the dataset')
23 |     parser.add_argument('--load_queue_size', type=int, default=1, help='maximum number of batches to load on the memory')
24 |     parser.add_argument('--n_workers', type=int, default=1, help='number of workers used in data loader')
25 |     # generation option
26 |     parser.add_argument('--exp_no', type=int, default=0, help='')
27 |     parser.add_argument('--out_dir', type=str, default='generated', help='')
28 |     parser.add_argument('--init_from', type=str, default='', help='load parameters from...')
29 |     parser.add_argument('--caption', type=str, default='', help='text to generate speech')
30 |     parser.add_argument('--teacher_forcing_ratio', type=float, default=0, help='value between 0~1, use this for scheduled sampling')
31 |     # audio related option
32 |     parser.add_argument('--n_fft', type=int, default=2048, help='fft bin size')
33 |     parser.add_argument('--sample_rate', type=int, default=16000, help='sampling rate')
34 |     parser.add_argument('--frame_len_inMS', type=int, default=50, help='used to determine window size of fft')
35 |     parser.add_argument('--frame_shift_inMS', type=float, default=12.5, help='used to determine stride in stft')
36 |     parser.add_argument('--num_recon_iters', type=int, default=50, help='# of
iteration in griffin-lim recon') 37 | # misc 38 | parser.add_argument('--gpu', type=int, nargs='+', help='index of gpu machines to run') 39 | parser.add_argument('--seed', type=int, default=0, help='random seed') 40 | new_args = vars(parser.parse_args()) 41 | 42 | # load and override some arguments 43 | checkpoint = torch.load(new_args['init_from'], map_location=lambda storage, loc: storage) 44 | args = checkpoint['args'] 45 | for i in new_args: 46 | args.__dict__[i] = new_args[i] 47 | 48 | torch.manual_seed(args.seed) 49 | 50 | # set dataset option 51 | if args.data == 'blizzard': 52 | args.dir_bin = '/data2/lyg0722/TTS_corpus/blizzard/segmented/bin/' 53 | elif args.data == 'etri': 54 | args.dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/' 55 | else: 56 | print('no dataset') 57 | return 58 | 59 | if args.gpu is None: 60 | args.use_gpu = False 61 | args.gpu = [] 62 | else: 63 | args.use_gpu = True 64 | torch.cuda.manual_seed(0) 65 | torch.cuda.set_device(args.gpu[0]) 66 | 67 | model = Tacotron(args) 68 | criterion_mel = nn.L1Loss(size_average=False) 69 | criterion_lin = nn.L1Loss(size_average=False) 70 | 71 | window_len = int(np.ceil(args.frame_len_inMS * args.sample_rate / 1000)) 72 | hop_length = int(np.ceil(args.frame_shift_inMS * args.sample_rate / 1000)) 73 | 74 | if args.init_from: 75 | model.load_state_dict(checkpoint['state_dict']) 76 | print('loaded checkpoint %s' % (args.init_from)) 77 | 78 | model = model.eval() 79 | 80 | if args.use_gpu: 81 | model = model.cuda() 82 | criterion_mel = criterion_mel.cuda() 83 | criterion_lin = criterion_lin.cuda() 84 | 85 | if args.caption: 86 | text_raw = args.caption 87 | 88 | if args.data == 'etri': 89 | text_raw = decompose_hangul(text_raw) # For Korean dataset 90 | 91 | vocab_dict = torch.load(args.dir_bin + 'vocab.t7') 92 | 93 | enc_input = [vocab_dict[i] for i in text_raw] 94 | enc_input = enc_input + [0] # null-padding at tail 95 | text_lengths = [len(enc_input)] 96 | enc_input = Variable(torch.LongTensor(enc_input).view(1,-1)) 97 | 98 | dec_input = torch.Tensor(1, 1, args.dec_out_size).fill_(0) # null-padding for start flag 99 | dec_input = Variable(dec_input) 100 | wave_lengths = [args.wave_limit] # TODO: use later... 
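# Shapes at this point in the single-caption path, as built above:
#   enc_input    - 1 x T LongTensor of character indices (index 0 is reserved for null-padding)
#   dec_input    - 1 x 1 x dec_out_size zero frame, serving as the start ("GO") frame for the decoder
#   wave_lengths - [wave_limit]; model.forward runs the decoder for max(wave_lengths) // r_factor steps
# prev_h is initialized to (None, None, None) just below; the GRUs treat None as zero initial hidden states.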
101 | 102 | prev_h = (None, None, None) # set prev_h = h_0 when new sentences are loaded 103 | 104 | if args.gpu: 105 | enc_input = enc_input.cuda() 106 | dec_input = dec_input.cuda() 107 | 108 | _, pred_lin, prev_h = model(enc_input, dec_input, wave_lengths, text_lengths, prev_h) 109 | 110 | # start generation 111 | wave = spectrogram2wav( 112 | pred_lin.data.view(-1, args.post_out_size).cpu().numpy(), 113 | n_fft=args.n_fft, 114 | win_length=window_len, 115 | hop_length=hop_length, 116 | num_iters=args.num_recon_iters 117 | ) 118 | 119 | # write to file 120 | outpath1 = '%s/%s_%s.wav' % (args.out_dir, args.exp_no, args.caption) 121 | outpath2 = '%s/%s_%s.png' % (args.out_dir, args.exp_no, args.caption) 122 | librosa.output.write_wav(outpath1, wave, 16000) 123 | saveAttention(text_raw, torch.cat(model.attn_weights, dim=-1).squeeze(), outpath2) 124 | else: 125 | loader = DataLoader(args) 126 | args.vocab_size = loader.get_num_vocab() 127 | 128 | for iter in range(1, loader.iter_per_epoch + 1): 129 | if loader.is_subbatch_end: 130 | prev_h = (None, None, None) # set prev_h = h_0 when new sentences are loaded 131 | 132 | for i in range(args.batch_idx): 133 | loader.next_batch('train') 134 | 135 | enc_input, target_mel, target_lin, wave_lengths, text_lengths = loader.next_batch('train') 136 | enc_input = Variable(enc_input, volatile=True) 137 | target_mel = Variable(target_mel, volatile=True) 138 | target_lin = Variable(target_lin, volatile=True) 139 | 140 | prev_h = loader.mask_prev_h(prev_h) 141 | 142 | if args.gpu: 143 | enc_input = enc_input.cuda() 144 | target_mel = target_mel.cuda() 145 | target_lin = target_lin.cuda() 146 | 147 | pred_mel, pred_lin, prev_h = model(enc_input, target_mel[:, :-1], wave_lengths, text_lengths, prev_h) 148 | 149 | loss_mel = criterion_mel(pred_mel, target_mel[:, 1:]) \ 150 | .div(max(wave_lengths) * args.batch_size * args.dec_out_size) 151 | loss_linear = criterion_lin(pred_lin, target_lin[:, 1:]) \ 152 | .div(max(wave_lengths) * args.batch_size * args.post_out_size) 153 | loss = torch.sum(loss_mel + loss_linear) 154 | 155 | print('loss:' , loss.data[0]) 156 | 157 | attentions = torch.cat(model.attn_weights, dim=-1) 158 | 159 | # write to file 160 | for n in range(enc_input.size(0)): 161 | wave = spectrogram2wav( 162 | pred_lin.data[n].view(-1, args.post_out_size).cpu().numpy(), 163 | n_fft=args.n_fft, 164 | win_length=window_len, 165 | hop_length=hop_length, 166 | num_iters=args.num_recon_iters 167 | ) 168 | outpath1 = '%s/%s_%s_%s.wav' % (args.out_dir, args.exp_no, n, args.caption) 169 | librosa.output.write_wav(outpath1, wave, 16000) 170 | outpath2 = '%s/%s_%s_%s.png' % (args.out_dir, args.exp_no, n, args.caption) 171 | saveAttention(None, attentions[n], outpath2) 172 | 173 | 174 | # showPlot(plot_losses) 175 | break 176 | 177 | ###################################################################### 178 | # This is a helper function to print time elapsed and estimated time 179 | # remaining given the current time and progress %. 
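# Illustrative example: at 25% progress with 192 seconds elapsed,
# timeSince(start, 0.25) returns '3m 12s (- 9m 36s)' (elapsed time, then estimated time remaining).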
180 | #
181 | 
182 | import time
183 | import math
184 | 
185 | 
186 | def asMinutes(s):
187 |     m = math.floor(s / 60)
188 |     s -= m * 60
189 |     return '%dm %ds' % (m, s)
190 | 
191 | 
192 | def timeSince(since, percent):
193 |     now = time.time()
194 |     s = now - since
195 |     es = s / (percent)
196 |     rs = es - s
197 |     return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
198 | 
199 | 
200 | def saveAttention(input_sentence, attentions, outpath):
201 |     # Set up figure with colorbar
202 |     import matplotlib
203 |     matplotlib.use('Agg')
204 |     import matplotlib.pyplot as plt
205 |     import matplotlib.ticker as ticker
206 | 
207 |     fig = plt.figure(figsize=(24,10), )
208 |     ax = fig.add_subplot(111)
209 |     cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
210 |     fig.colorbar(cax)
211 | 
212 |     if input_sentence:
213 |         # Set up axes
214 |         ax.set_yticklabels([' '] + list(input_sentence) + [' '])
215 |         # Show label at every tick
216 |         ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
217 | 
218 |     plt.tight_layout()
219 |     plt.savefig(outpath)
220 |     plt.close('all')
221 | 
222 | 
223 | def spectrogram2wav(spectrogram, n_fft, win_length, hop_length, num_iters):
224 |     '''
225 |     spectrogram: [t, f], i.e. [t, nfft // 2 + 1]
226 |     '''
227 |     min_level_db = -100
228 |     ref_level_db = 20
229 | 
230 |     spec = spectrogram.T
231 |     # denormalize
232 |     spec = (np.clip(spec, 0, 1) * - min_level_db) + min_level_db
233 |     spec = spec + ref_level_db
234 | 
235 |     # Convert back to linear
236 |     spec = np.power(10.0, spec * 0.05)
237 | 
238 |     return _griffin_lim(spec ** 1.5, n_fft, win_length, hop_length, num_iters)  # Reconstruct phase
239 | 
240 | 
241 | def _griffin_lim(S, n_fft, win_length, hop_length, num_iters):
242 |     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
243 |     S_complex = np.abs(S).astype(np.complex)
244 |     for i in range(num_iters):
245 |         if i > 0:
246 |             angles = np.exp(1j * np.angle(librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)))
247 |         y = librosa.istft(S_complex * angles, hop_length=hop_length, win_length=win_length)
248 |     return y
249 | 
250 | 
251 | def decompose_hangul(text):
252 |     """
253 |     Code from: https://github.com/neotune/python-korean-handler
254 |     """
255 | 
256 |     # Unicode Hangul syllable range: start 44032, end 55199
257 |     Start_Code, ChoSung, JungSung = 44032, 588, 28
258 | 
259 |     # Choseong (initial consonant) list: 00 ~ 18
260 |     ChoSung_LIST = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
261 | 
262 |     # Jungseong (medial vowel) list: 00 ~ 20
263 |     JungSung_LIST = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ',
264 |                      'ㅣ']
265 | 
266 |     # Jongseong (final consonant) list: 00 ~ 27, plus one empty entry
267 |     JongSung_LIST = ['', 'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ',
268 |                      'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
269 | 
270 |     line_dec = ""
271 |     line = list(text.strip())
272 | 
273 |     for keyword in line:
274 |         # Check whether the character is Hangul (ㄱ~ㅎ + ㅏ~ㅣ + 가~힣), then decompose it
275 |         if re.match('.*[ㄱ-ㅎㅏ-ㅣ가-힣]+.*', keyword) is not None:
276 |             char_code = ord(keyword) - Start_Code
277 |             char1 = int(char_code / ChoSung)
278 |             line_dec += ChoSung_LIST[char1]
279 |             #print('Choseong: {}'.format(ChoSung_LIST[char1]))
280 |             char2 = int((char_code - (ChoSung * char1)) / JungSung)
281 |             line_dec += JungSung_LIST[char2]
282 |             #print('Jungseong: {}'.format(JungSung_LIST[char2]))
283 |             char3 = int((char_code - (ChoSung * char1) - (JungSung * char2)))
284 |             line_dec += JongSung_LIST[char3]
285 |             #print('Jongseong: {}'.format(JongSung_LIST[char3]))
286 |         else:
287 |             line_dec += keyword
288 | 
289 |     return line_dec
290 | 
291 | if __name__ == '__main__':
292 |     try:
293 |         main()
294 |     finally:
295 |         for p in multiprocessing.active_children():
296 |             p.terminate()
297 | 
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import unicode_literals, print_function, division
3 | from torch.multiprocessing import Process, Queue, Pool
4 | from torch.autograd import Variable
5 | from functools import partial
6 | import math, torch, pickle
7 | import os.path
8 | 
9 | 
10 | class DataLoader():
11 |     def __init__(self, args):
12 |         self.dir_bin = args.dir_bin
13 |         line_load_list = self.dir_bin + 'line_load_list.t7'
14 |         vocab_file = self.dir_bin + 'vocab.t7'
15 |         assert os.path.isfile(self.dir_bin + 'specM.bin')
16 |         assert os.path.isfile(self.dir_bin + 'specL.bin')
17 |         assert os.path.isfile(self.dir_bin + 'text.bin')
18 | 
19 |         self.batch_size = args.batch_size
20 |         self.trunc_size = args.trunc_size
21 |         self.r_factor = args.r_factor
22 |         self.dec_out_size = args.dec_out_size
23 |         self.post_out_size = args.post_out_size
24 |         self.shuffle_data = True if args.shuffle_data == 1 else False
25 |         self.iter_per_epoch = None
26 |         self.is_subbatch_end = True
27 |         self.curr_split = None
28 |         self.vocab_size = None
29 | 
30 |         self.process = None
31 |         self.queue = Queue(maxsize=args.load_queue_size)
32 |         self.n_workers = args.n_workers
33 | 
34 |         self.use_gpu = args.use_gpu
35 |         self.num_gpu = len(args.gpu) if len(args.gpu) > 0 else 1
36 |         self.pinned_memory = True if args.pinned_memory == 1 and self.use_gpu else False
37 | 
38 |         self.vocab_size = self.get_num_vocab(vocab_file)
39 |         text_limit = args.text_limit
40 |         wave_limit = args.wave_limit
41 | 
42 |         # col1: idx / col2: wave_length / col3: text_length
43 |         # col4: offset_M / col5: offset_L / col6: offset_T / col7: len_M / col8: len_L / col9: len_T
44 |         self.load_list = torch.load(line_load_list)
45 |         spec_len_list = self.load_list[:, 1].clone()
46 |         text_len_list = self.load_list[:, 2].clone()
47 | 
48 |         # exclude files whose wave length exceeds wave_limit
49 |         sort_length, sort_idx = spec_len_list.sort()
50 |         text_len_list = torch.gather(text_len_list, 0, sort_idx)
51 |         sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list)
52 |         self.load_list = torch.gather(self.load_list, 0, sort_idx)
53 | 
54 |         end_idx = sort_length.le(wave_limit).sum()
55 |         spec_len_list = sort_length[:end_idx]
56 |         text_len_list = text_len_list[:end_idx]
57 |         self.load_list = self.load_list[:end_idx]
58 | 
59 |         # exclude files whose text length exceeds text_limit
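        # (Same idiom as the wave-length filter above: sort ascending by length, permute all
        #  per-item tensors with the same sort_idx via torch.gather, then keep the leading
        #  slice of size sort_length.le(limit).sum(), i.e. the items within the limit.)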
60 | sort_length, sort_idx = text_len_list.sort() 61 | spec_len_list = torch.gather(spec_len_list, 0, sort_idx) 62 | sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list) 63 | self.load_list = torch.gather(self.load_list, 0, sort_idx) 64 | 65 | end_idx = sort_length.le(text_limit).sum() 66 | end_idx = end_idx - (end_idx % self.batch_size) # drop residual data 67 | text_len_list = sort_length[:end_idx] 68 | spec_len_list = spec_len_list[:end_idx] 69 | self.load_list = self.load_list[:end_idx] 70 | 71 | # sort by wave length 72 | _, sort_idx = spec_len_list.sort(0, descending=True) 73 | text_len_list = torch.gather(text_len_list, 0, sort_idx) 74 | sort_idx = sort_idx.view(-1, 1).expand_as(self.load_list) 75 | self.load_list = torch.gather(self.load_list, 0, sort_idx) 76 | 77 | # sort by text length in each batch (PackedSequence requires it) 78 | num_batches_per_epoch = self.load_list.size(0) // self.batch_size 79 | text_len_list = text_len_list.view(num_batches_per_epoch, -1) 80 | self.load_list = self.load_list.view(num_batches_per_epoch, -1, self.load_list.size(1)) 81 | sort_length, sort_idx = text_len_list.sort(1, descending=True) 82 | sort_idx = sort_idx.view(num_batches_per_epoch, -1, 1).expand_as(self.load_list) 83 | self.load_list = torch.gather(self.load_list, 1, sort_idx) 84 | 85 | # shuffle while preserving order in a batch 86 | if self.shuffle_data: 87 | _, sort_idx = torch.randn(num_batches_per_epoch).sort() 88 | sort_idx = sort_idx.view(-1, 1, 1).expand_as(self.load_list) 89 | self.load_list = torch.gather(self.load_list, 0, sort_idx) # nbpe x N x 6 90 | 91 | self.load_list = self.load_list.long() 92 | 93 | # compute number of iterations needed 94 | spec_len_list = spec_len_list.view(num_batches_per_epoch, -1) 95 | spec_len_list, _ = spec_len_list.div(self.trunc_size).ceil().max(1) 96 | self.iter_per_epoch = int(spec_len_list.sum()) 97 | 98 | # set split cursor 99 | self.split_sizes = {'train': self.load_list.size(0), 'val': -1, 'test': -1} 100 | self.split_cursor = {'train': 0, 'val': 0, 'test': 0} 101 | 102 | 103 | def next_batch(self, split): 104 | T, idx = self.trunc_size, self.split_cursor[split] 105 | 106 | # seek and load data from raw files 107 | if self.is_subbatch_end: 108 | self.is_subbatch_end = False 109 | self.subbatch_cursor = 0 110 | 111 | if self.curr_split != split: 112 | self.curr_split = split 113 | if self.process is not None: 114 | self.process.terminate() 115 | self.process = Process(target=self.start_async_loader, args=(split, self.split_cursor[split])) 116 | self.process.start() 117 | 118 | self.len_text, self.len_wave, self.curr_text, self.curr_specM, self.curr_specL = self.queue.get() 119 | self.split_cursor[split] = (idx + 1) % self.split_sizes[split] 120 | self.subbatch_max_len = self.len_wave.max() 121 | 122 | # Variables to return 123 | # +1 to length of y to consider shifting for target y 124 | subbatch_len_text = [x for x in self.len_text] 125 | subbatch_len_wave = [min(x, T) for x in self.len_wave] 126 | x_text = self.curr_text 127 | y_specM = self.curr_specM[:, self.subbatch_cursor:self.subbatch_cursor + max(subbatch_len_wave) + 1].contiguous() 128 | y_specL = self.curr_specL[:, self.subbatch_cursor:self.subbatch_cursor + max(subbatch_len_wave) + 1].contiguous() 129 | 130 | if self.use_gpu: 131 | if self.pinned_memory: 132 | x_text = x_text.pin_memory() 133 | y_specM = y_specM.pin_memory() 134 | y_specL = y_specL.pin_memory() 135 | 136 | x_text = x_text.cuda() 137 | y_specM = y_specM.cuda() 138 | y_specL = y_specL.cuda() 139 | 140 | # 
Advance split_cursor or Move on to the next batch 141 | if self.subbatch_cursor + T < self.subbatch_max_len: 142 | self.subbatch_cursor = self.subbatch_cursor + T 143 | self.len_wave.sub_(T).clamp_(min=0) 144 | else: 145 | self.is_subbatch_end = True 146 | 147 | # Don't compute for empty batch elements 148 | if subbatch_len_wave.count(0) > 0: 149 | self.len_wave_mask = [idx for idx, l in enumerate(subbatch_len_wave) if l > 0] 150 | self.len_wave_mask = torch.LongTensor(self.len_wave_mask) 151 | if self.use_gpu: 152 | self.len_wave_mask = self.len_wave_mask.cuda() 153 | 154 | x_text = torch.index_select(x_text, 0, self.len_wave_mask) 155 | y_specM = torch.index_select(y_specM, 0, self.len_wave_mask) 156 | y_specL = torch.index_select(y_specL, 0, self.len_wave_mask) 157 | subbatch_len_text = [subbatch_len_text[idx] for idx in self.len_wave_mask] 158 | subbatch_len_wave = [subbatch_len_wave[idx] for idx in self.len_wave_mask] 159 | else: 160 | self.len_wave_mask = None 161 | 162 | return x_text, y_specM, y_specL, subbatch_len_wave, subbatch_len_text 163 | 164 | 165 | def start_async_loader(self, split, load_start_idx): 166 | # load batches to the queue asynchronously since it is a bottle-neck 167 | N, r = self.batch_size, self.r_factor 168 | load_curr_idx = load_start_idx 169 | 170 | while True: 171 | data_T, data_M, data_L, len_T, len_M = ([None for _ in range(N)] for _ in range(5)) 172 | # deploy workers to load data 173 | self.pool = Pool(self.n_workers) 174 | partial_func = partial(load_data_and_length, self.dir_bin, self.load_list[load_curr_idx]) 175 | results = self.pool.map_async(func=partial_func, iterable=range(N)) 176 | self.pool.close() 177 | self.pool.join() 178 | 179 | for result in results.get(): 180 | data_M[result[0]] = result[1] 181 | data_L[result[0]] = result[2] 182 | data_T[result[0]] = result[3] 183 | len_T[result[0]] = result[4] 184 | len_M[result[0]] = result[5] 185 | 186 | # TODO: output size is not accurate.. 
187 |             len_text = torch.IntTensor(len_T)
188 |             len_wave = torch.Tensor(len_M).div(r).ceil().mul(r).int()  # consider r_factor
189 |             curr_text = torch.LongTensor(N, len_text.max()).fill_(0)  # null-padding at tail
190 |             curr_specM = torch.Tensor(N, len_wave.max() + 1, self.dec_out_size).fill_(0)  # null-padding at tail
191 |             curr_specL = torch.Tensor(N, len_wave.max() + 1, self.post_out_size).fill_(0)  # null-padding at tail
192 | 
193 |             # fill the template tensors
194 |             for j in range(N):
195 |                 curr_text[j, 0:data_T[j].size(0)].copy_(data_T[j])
196 |                 curr_specM[j, 0:data_M[j].size(0)].copy_(data_M[j])
197 |                 curr_specL[j, 0:data_L[j].size(0)].copy_(data_L[j])
198 | 
199 |             self.queue.put((len_text, len_wave, curr_text, curr_specM, curr_specL))
200 |             load_curr_idx = (load_curr_idx + 1) % self.split_sizes[split]
201 | 
202 | 
203 |     def mask_prev_h(self, prev_h):
204 |         if self.len_wave_mask is not None:
205 |             if self.use_gpu:
206 |                 self.len_wave_mask = self.len_wave_mask.cuda()
207 | 
208 |             h_att, h_dec1, h_dec2 = prev_h
209 |             h_att = torch.index_select(h_att.data, 1, self.len_wave_mask)  # batch idx is dim 1
210 |             h_dec1 = torch.index_select(h_dec1.data, 1, self.len_wave_mask)
211 |             h_dec2 = torch.index_select(h_dec2.data, 1, self.len_wave_mask)
212 |             prev_h = (Variable(h_att), Variable(h_dec1), Variable(h_dec2))
213 |         else:
214 |             prev_h = prev_h
215 | 
216 |         return prev_h
217 | 
218 | 
219 |     def get_num_vocab(self, vocab_file=None):
220 |         if self.vocab_size:
221 |             return self.vocab_size
222 |         else:
223 |             vocab_dict = torch.load(vocab_file)
224 |             return len(vocab_dict) + 1  # +1 to consider null-padding
225 | 
226 | 
227 | def load_binary(file_path, offset, length):
228 |     with open(file_path, 'rb') as datafile:
229 |         datafile.seek(offset)
230 |         line = datafile.read(length)
231 |         obj = pickle.loads(line)
232 |     return obj
233 | 
234 | 
235 | def load_data_and_length(dir_bin, load_info, load_idx):
236 |     # If an out-of-range error occurs here, check whether you are using the right text_limit.
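    # Layout of load_info[load_idx] (a row of line_load_list, built in preprocess.py):
    #   [0] line idx, [1] wave_length, [2] text_length,
    #   [3]/[4]/[5] byte offsets into specM.bin / specL.bin / text.bin,
    #   [6]/[7]/[8] byte lengths of the corresponding pickled records.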
237 | data_M = load_binary(dir_bin + 'specM.bin', load_info[load_idx][3], load_info[load_idx][6]) 238 | data_L = load_binary(dir_bin + 'specL.bin', load_info[load_idx][4], load_info[load_idx][7]) 239 | data_T = load_binary(dir_bin + 'text.bin', load_info[load_idx][5], load_info[load_idx][8]) 240 | # Convert to Tensor 241 | data_M = torch.from_numpy(data_M) 242 | data_L = torch.from_numpy(data_L) 243 | data_T = torch.LongTensor(data_T) 244 | 245 | len_M = data_M.size(0) 246 | len_T = data_T.size(0) 247 | return (load_idx, data_M, data_L, data_T, len_T, len_M) 248 | 249 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals, print_function, division 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.utils.rnn as rnn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import numpy as np 10 | import random 11 | 12 | 13 | class Tacotron(nn.Module): 14 | def __init__(self, args): 15 | super(Tacotron, self).__init__() 16 | self.trunc_size = args.trunc_size 17 | self.r_factor = args.r_factor 18 | self.dec_out_size = args.dec_out_size 19 | 20 | self.teacher_forcing_ratio = args.teacher_forcing_ratio 21 | self.attn_weights = [] # only used in evaluation 22 | 23 | self.encoder = Encoder(args.vocab_size, args.charvec_dim, args.hidden_size, args.num_filters, args.dropout) 24 | self.linear_enc = nn.Linear(2 * args.hidden_size, 2 * args.hidden_size, bias=False) # N*T_enc x 2H 25 | 26 | self.decoder = AttnDecoderRNN(args.hidden_size, args.dec_out_size, args.r_factor, args.dropout) 27 | self.post_processor = PostProcessor(args.hidden_size, args.dec_out_size, args.post_out_size, args.num_filters // 2) 28 | 29 | 30 | def forward(self, enc_input, dec_input, wave_lengths, text_lengths, prev_h): 31 | r = self.r_factor 32 | T_wav, T_dec = max(wave_lengths), max(wave_lengths)//r 33 | 34 | enc_output = self.encoder(enc_input, text_lengths) 35 | in_attW_enc = rnn.pack_padded_sequence(enc_output, text_lengths, True) 36 | in_attW_enc = self.linear_enc(in_attW_enc.data) # N*T_enc x 2H 37 | 38 | output_mel_list = [] 39 | prev_dec_output = dec_input[:, 0] 40 | h_att, h_dec1, h_dec2 = prev_h 41 | 42 | for di in range(T_dec): 43 | start_idx, end_idx = di*r, (di+1)*r 44 | 45 | prev_dec_output, h_att, h_dec1, h_dec2 = self.decoder( 46 | enc_output, in_attW_enc, prev_dec_output, text_lengths, h_att, h_dec1, h_dec2) 47 | 48 | output_mel_list.append(prev_dec_output) 49 | 50 | if random.random() < self.teacher_forcing_ratio: 51 | prev_dec_output = dec_input[:, end_idx-1] # Teacher forcing 52 | else: 53 | prev_dec_output = output_mel_list[-1][:, -1] 54 | 55 | if not self.training: 56 | self.attn_weights.append(self.decoder.attn_weights.data) 57 | 58 | # TODO: make it stop when it meets EOS token 59 | 60 | output_mel = torch.cat(output_mel_list, dim=1) 61 | output_linear = self.post_processor(output_mel) 62 | last_h = (Variable(h_att.data), Variable(h_dec1.data), Variable(h_dec2.data)) 63 | return output_mel, output_linear, last_h 64 | 65 | 66 | class Encoder(nn.Module): 67 | """ input[0]: NxT sized Tensor 68 | input[1]: B sized lengths Tensor 69 | output: NxTxH sized Tensor 70 | """ 71 | def __init__(self, vocab_size, charvec_dim, hidden_size, num_filters, dropout_p=0.5): 72 | super(Encoder, self).__init__() 73 | self.hidden_size = hidden_size 74 | 75 | self.embedding = nn.Embedding(vocab_size, 
charvec_dim) 76 | self.prenet = nn.Sequential( 77 | nn.Linear(charvec_dim, 2 * hidden_size), 78 | nn.ReLU(), 79 | nn.Dropout(dropout_p), 80 | nn.Linear(2 * hidden_size, hidden_size), 81 | nn.ReLU(), 82 | nn.Dropout(dropout_p) 83 | ) 84 | self.CBHG = CBHG(hidden_size, hidden_size, hidden_size, hidden_size, hidden_size, num_filters, True) 85 | 86 | def forward(self, input, lengths): 87 | N, T = input.size(0), input.size(1) 88 | embedded = self.embedding(input).view(N*T, -1) # NT x C 89 | output = self.prenet(embedded).view(N, T, -1) # N x T x H 90 | output = self.CBHG(output, lengths) 91 | return output 92 | 93 | 94 | class AttnDecoderRNN(nn.Module): 95 | """ input_enc: Output from encoder (NxTx2H) 96 | input_attW_enc: masked-linear transformed input_enc 97 | input_dec: Output from previous-step decoder (NxO_dec) 98 | lengths: N sized Tensor 99 | output: N x r x H sized Tensor 100 | """ 101 | def __init__(self, hidden_size, output_size, r_factor=2, dropout_p=0.5): 102 | super(AttnDecoderRNN, self).__init__() 103 | self.r_factor = r_factor 104 | 105 | self.prenet = nn.Sequential( 106 | nn.Linear(output_size, 2 * hidden_size), 107 | nn.ReLU(), 108 | nn.Dropout(dropout_p), 109 | nn.Linear(2 * hidden_size, hidden_size), 110 | nn.ReLU(), 111 | nn.Dropout(dropout_p) 112 | ) 113 | self.linear_dec = nn.Linear(2 * hidden_size, 2 * hidden_size) 114 | self.gru_att = nn.GRU(hidden_size, 2 * hidden_size, batch_first=True) 115 | 116 | self.attn = nn.Linear(2 * hidden_size, 1) # TODO: change name... 117 | 118 | self.short_cut = nn.Linear(4 * hidden_size, 2 * hidden_size) 119 | self.gru_dec1 = nn.GRU(4 * hidden_size, 2 * hidden_size, num_layers=1, batch_first=True) 120 | self.gru_dec2 = nn.GRU(2 * hidden_size, 2 * hidden_size, num_layers=1, batch_first=True) 121 | 122 | self.out = nn.Linear(2 * hidden_size, r_factor * output_size) 123 | 124 | def forward(self, input_enc, input_attW_enc, input_dec, lengths_enc, hidden_att=None, hidden_dec1=None, hidden_dec2=None): 125 | N = input_dec.size(0) 126 | 127 | out_att = self.prenet(input_dec).unsqueeze(1) # N x O_dec -> N x 1 x H 128 | out_att, hidden_att = self.gru_att(out_att, hidden_att) # N x 1 x 2H 129 | in_attW_dec = self.linear_dec(out_att.squeeze(1)).unsqueeze(1).expand_as(input_enc) 130 | in_attW_dec = rnn.pack_padded_sequence(in_attW_dec, lengths_enc, True) # N*T_enc x 2H 131 | 132 | self.attn_weights = torch.add(input_attW_enc, in_attW_dec.data).tanh() # N x T_enc x 2H 133 | self.attn_weights = self.attn(self.attn_weights).exp() # N*T_enc x 1 134 | self.attn_weights = rnn.PackedSequence(self.attn_weights, in_attW_dec.batch_sizes) 135 | self.attn_weights, _ = rnn.pad_packed_sequence(self.attn_weights, True) 136 | self.attn_weights = F.normalize(self.attn_weights, 1, 1) # N x T_enc x 1 137 | 138 | attn_applied = torch.bmm(self.attn_weights.transpose(1,2), input_enc) # N x 1 x 2H 139 | 140 | out_dec = torch.cat((attn_applied, out_att), 2) # N x 1 x 4H 141 | residual = self.short_cut(out_dec.squeeze(1)).unsqueeze(1) # N x 1 x 2H 142 | 143 | out_dec, hidden_dec1 = self.gru_dec1(out_dec, hidden_dec1) 144 | residual = residual + out_dec 145 | 146 | out_dec, hidden_dec2 = self.gru_dec2(residual, hidden_dec2) 147 | residual = residual + out_dec 148 | 149 | output = self.out(residual.squeeze(1)).view(N, self.r_factor, -1) 150 | return output, hidden_att, hidden_dec1, hidden_dec2 151 | 152 | 153 | class PostProcessor(nn.Module): 154 | """ input: N x T x O_dec 155 | output: N x T x O_post 156 | """ 157 | def __init__(self, hidden_size, dec_out_size, post_out_size, 
num_filters): 158 | super(PostProcessor, self).__init__() 159 | self.CBHG = CBHG(dec_out_size, hidden_size, 2 * hidden_size, hidden_size, hidden_size, num_filters, True) 160 | self.projection = nn.Linear(2 * hidden_size, post_out_size) 161 | 162 | def forward(self, input, lengths=None): 163 | if lengths is None: 164 | N, T = input.size(0), input.size(1) 165 | lengths = [T for _ in range(N)] 166 | output = self.CBHG(input, lengths).contiguous().view(N*T,-1) 167 | output = self.projection(output).view(N,T,-1) 168 | else: 169 | output = self.CBHG(input, lengths) 170 | output = rnn.pack_padded_sequence(output, lengths, True) 171 | output = rnn.PackedSequence(self.projection(output.data), output.batch_sizes) 172 | output, _ = rnn.pad_packed_sequence(output, True) 173 | return output 174 | 175 | 176 | class CBHG(nn.Module): 177 | """ input: NxTxinput_dim sized Tensor 178 | output: NxTx2gru_dim sized Tensor 179 | """ 180 | def __init__(self, input_dim, conv_bank_dim, conv_dim1, conv_dim2, gru_dim, num_filters, is_masked): 181 | super(CBHG, self).__init__() 182 | self.num_filters = num_filters 183 | 184 | bank_out_dim = num_filters * conv_bank_dim 185 | self.conv_bank = nn.ModuleList() 186 | for i in range(num_filters): 187 | self.conv_bank.append(nn.Conv1d(input_dim, conv_bank_dim, i + 1, stride=1, padding=int(np.ceil(i / 2)))) 188 | 189 | # define batch normalization layer, we use BN1D since the sequence length is not fixed 190 | self.bn_list = nn.ModuleList() 191 | self.bn_list.append(nn.BatchNorm1d(bank_out_dim)) 192 | self.bn_list.append(nn.BatchNorm1d(conv_dim1)) 193 | self.bn_list.append(nn.BatchNorm1d(conv_dim2)) 194 | 195 | self.conv1 = nn.Conv1d(bank_out_dim, conv_dim1, 3, stride=1, padding=1) 196 | self.conv2 = nn.Conv1d(conv_dim1, conv_dim2, 3, stride=1, padding=1) 197 | 198 | if input_dim != conv_dim2: 199 | self.residual_proj = nn.Linear(input_dim, conv_dim2) 200 | 201 | self.highway = Highway(conv_dim2, 4) 202 | self.BGRU = nn.GRU(input_size=conv_dim2, hidden_size=gru_dim, num_layers=1, batch_first=True, bidirectional=True) 203 | 204 | def forward(self, input, lengths): 205 | N, T = input.size(0), input.size(1) 206 | 207 | conv_bank_out = [] 208 | input_t = input.transpose(1, 2) # NxTxH -> NxHxT 209 | for i in range(self.num_filters): 210 | tmp_input = input_t 211 | if i % 2 == 0: 212 | tmp_input = tmp_input.unsqueeze(-1) 213 | tmp_input = F.pad(tmp_input, (0,0,0,1)).squeeze(-1) # NxHxT 214 | conv_bank_out.append(self.conv_bank[i](tmp_input)) 215 | 216 | residual = torch.cat(conv_bank_out, dim=1) # NxHFxT 217 | residual = F.relu(self.bn_list[0](residual)) 218 | residual = F.max_pool1d(residual, 2, stride=1) 219 | residual = self.conv1(residual) # NxHxT 220 | residual = F.relu(self.bn_list[1](residual)) 221 | residual = self.conv2(residual) # NxHxT 222 | residual = self.bn_list[2](residual).transpose(1,2) # NxHxT -> NxTxH 223 | 224 | rnn_input = input 225 | if rnn_input.size() != residual.size(): 226 | rnn_input = self.residual_proj(rnn_input) 227 | rnn_input = rnn_input + residual 228 | rnn_input = self.highway(rnn_input).view(N, T, -1) 229 | 230 | output = rnn.pack_padded_sequence(rnn_input, lengths, True) 231 | output, _ = self.BGRU(output) # zero h_0 is used by default 232 | output, _ = rnn.pad_packed_sequence(output, True) # NxTx2H 233 | return output 234 | 235 | 236 | class Highway(nn.Module): 237 | """ 238 | Code from: https://github.com/kefirski/pytorch_Highway 239 | """ 240 | def __init__(self, size, num_layers, f=F.relu): 241 | super(Highway, self).__init__() 242 | 
self.num_layers = num_layers 243 | self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)]) 244 | self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)]) 245 | self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)]) 246 | self.f = f 247 | 248 | def forward(self, x): 249 | """ input: NxH sized Tensor 250 | output: NxH sized Tensor 251 | """ 252 | for layer in range(self.num_layers): 253 | gate = F.sigmoid(self.gate[layer](x)) 254 | nonlinear = self.f(self.nonlinear[layer](x)) 255 | linear = self.linear[layer](x) 256 | x = gate * nonlinear + (1 - gate) * linear 257 | return x 258 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals, print_function, division 3 | from multiprocessing import Process, Queue 4 | import os, re, librosa, argparse, torch, pickle, multiprocessing 5 | import numpy as np 6 | 7 | def get_vocab(dataset, directory): 8 | """ read files & create characters dictionary 9 | Decide what characters to remove at this stage 10 | """ 11 | if dataset == 'blizzard': 12 | preprocess_text_blizzard(directory, 'prompts.gui') 13 | elif dataset == 'nancy': 14 | preprocess_text_nancy(directory, 'prompts.data') 15 | elif dataset == 'etri': 16 | preprocess_text_etri(directory, 'prompts.data') 17 | 18 | vocabs = set([]) 19 | 20 | for root, dirnames, filenames in os.walk(directory): 21 | for filename in filenames: 22 | if filename[-4:] == '.txt': 23 | path = os.path.join(root, filename) 24 | with open(path, 'r') as rFile: 25 | line = rFile.readline() 26 | while line: 27 | line = line.strip() 28 | for i in range(len(line)): 29 | vocabs.add(line[i]) 30 | line = rFile.readline() 31 | vocabs = sorted(list(vocabs)) 32 | print(vocabs) 33 | 34 | 35 | def cleanse(directory, dir_bin, write_csv, spec_file_list): 36 | """ read files & cleanse each line 37 | look at cleansed file and check what to cleanse more 38 | """ 39 | text_file_list = [] 40 | cleansed_files = [] 41 | vocabs = set([]) 42 | vocabs.add('Z') # there is no 'Z' in blizzard corpus (may cause problem later) 43 | 44 | dirts = ['`', '#', '@', '\|'] 45 | dirts = '(' + '|'.join(dirts) + ')' 46 | reg_dirts = re.compile(dirts) 47 | reg_spaces = re.compile(r'\s+') 48 | reg_spacedSymbols = re.compile(r' (?P\W)') 49 | 50 | for root, dirnames, filenames in os.walk(directory): 51 | wFileName = root[len(directory):] 52 | writePath = dir_bin + wFileName + '.txt' 53 | with open(writePath, 'w') as wFile: 54 | for filename in sorted(filenames): 55 | if filename[-4:] == '.txt': 56 | readPath = os.path.join(root, filename) 57 | 58 | with open(readPath, 'r') as rFile: 59 | line = rFile.readline() 60 | while line and len(line) > 0: 61 | line = line.strip() 62 | line = reg_dirts.sub('', line) 63 | line = reg_spaces.sub(' ', line) 64 | line = reg_spacedSymbols.sub(r'\g', line) 65 | 66 | wFile.write(line+'\n') 67 | 68 | if write_csv: 69 | for i in range(len(line)): 70 | vocabs.add(line[i]) 71 | 72 | line = rFile.readline() 73 | if write_csv: 74 | cleansed_files.append((writePath, wFileName)) 75 | 76 | if write_csv: 77 | print('Start to write text binary files') 78 | if spec_file_list is None: 79 | print('spec_file_list is not found.') 80 | 81 | tmp_vocab = sorted(list(vocabs)) 82 | vocab_dict = {} 83 | for i, vocab in enumerate(tmp_vocab): 84 | vocab_dict[vocab] = i + 1 # zero will be used for 
null-padding
85 |         torch.save(vocab_dict, dir_bin + 'vocab.t7')
86 | 
87 |         write_path = dir_bin + 'text.bin'
88 |         offset = 0
89 |         with open(write_path, 'wb') as w_file:
90 |             for cleansed_tuple in cleansed_files:
91 |                 cleansed_file = cleansed_tuple[0]
92 | 
93 |                 if spec_file_list is not None:
94 |                     with open(cleansed_file, 'r') as rFile:
95 |                         count = 0
96 |                         line = rFile.readline().strip()
97 |                         while line:
98 |                             # print(line)
99 |                             line = [vocab_dict[y] for y in line]
100 | 
101 |                             binary_text = pickle.dumps(line, protocol=pickle.HIGHEST_PROTOCOL)
102 |                             w_file.write(binary_text)
103 |                             file_id = spec_file_list[count][0]
104 |                             text_file_list.append((file_id, len(line), offset, len(binary_text)))
105 |                             offset += len(binary_text)
106 | 
107 |                             line = rFile.readline().strip()
108 |                             count += 1
109 |                 else:
110 |                     wFileName = cleansed_tuple[1]
111 |                     writePath = dir_bin + wFileName + '_text.csv'
112 | 
113 |                     with open(writePath, 'w') as wFile:
114 |                         with open(cleansed_file, 'r') as rFile:
115 |                             line = rFile.readline().strip()
116 |                             while line:
117 |                                 # print(line)
118 |                                 line = [str(vocab_dict[y]) for y in line]
119 |                                 wFile.write(','.join(line)+'\n')
120 |                                 line = rFile.readline().strip()
121 | 
122 |     return text_file_list
123 | 
124 | 
125 | def preprocess_text_blizzard(directory, file):
126 |     readPath = directory + file
127 |     writePath = directory + '/prompts.txt'
128 | 
129 |     with open(writePath, 'w') as wFile:
130 |         with open(readPath, 'r') as rFile:
131 |             line = rFile.readline()  # 1st line
132 |             while line and len(line) > 0:
133 |                 line = rFile.readline()  # 2nd line (txt included)
134 |                 wFile.write(line)
135 | 
136 |                 rFile.readline()  # 3rd line
137 |                 line = rFile.readline()  # 1st line of next wav file
138 | 
139 | 
140 | def preprocess_text_nancy(directory, file):
141 |     readPath = directory + file
142 |     writePath = directory + '/prompts.txt'
143 | 
144 |     with open(writePath, 'w') as wFile:
145 |         with open(readPath, 'r') as rFile:
146 |             line = rFile.readline()
147 |             while line and len(line) > 0:
148 |                 wFile.write(line[line.find('"')+1:-4].strip()+'\n')
149 |                 line = rFile.readline()
150 | 
151 | 
152 | def preprocess_text_etri(directory, file):
153 |     """
154 |     Code from: https://github.com/neotune/python-korean-handler
155 |     """
156 |     readPath = directory + '/prompts.data'
157 |     writePath = directory + '/prompts.txt'
158 | 
159 |     # Unicode Hangul syllable range: start 44032, end 55199
160 |     Start_Code, ChoSung, JungSung = 44032, 588, 28
161 | 
162 |     # Choseong (initial consonant) list: 00 ~ 18
163 |     ChoSung_LIST = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
164 | 
165 |     # Jungseong (medial vowel) list: 00 ~ 20
166 |     JungSung_LIST = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ',
167 |                      'ㅣ']
168 | 
169 |     # Jongseong (final consonant) list: 00 ~ 27, plus one empty entry
170 |     JongSung_LIST = ['', 'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ',
171 |                      'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
172 | 
173 |     with open(writePath, 'w') as wFile:
174 |         with open(readPath, 'r', encoding="utf-8") as rFile:
175 |             line = rFile.readline()  # skip this line (utf8 header)
176 |             line = rFile.readline()
177 |             while line:
178 |                 line_dec = ""
179 |                 line = line.strip().split('\t')[1]
180 |                 line = list(line)
181 | 
182 |                 for keyword in line:
183 |                     # Check whether the character is Hangul (ㄱ~ㅎ + ㅏ~ㅣ + 가~힣), then decompose it
184 |                     if re.match('.*[ㄱ-ㅎㅏ-ㅣ가-힣]+.*', keyword) is not None:
185 |                         char_code = ord(keyword) - Start_Code
186 |                         char1 = int(char_code / ChoSung)
187 |                         line_dec += ChoSung_LIST[char1]
188 |                         #print('Choseong: {}'.format(ChoSung_LIST[char1]))
189 |                         char2 = int((char_code - (ChoSung * char1)) / JungSung)
190 |                         line_dec += JungSung_LIST[char2]
191 |                         #print('Jungseong: {}'.format(JungSung_LIST[char2]))
192 |                         char3 = int((char_code - (ChoSung * char1) - (JungSung * char2)))
193 |                         line_dec += JongSung_LIST[char3]
194 |                         #print('Jongseong: {}'.format(JongSung_LIST[char3]))
195 |                     else:
196 |                         line_dec += keyword
197 | 
198 |                 wFile.write(line_dec + '\n')
199 |                 line = rFile.readline()
200 | 
201 | 
202 | def trim_silence(audio, sr, frame_shift_inMS, file_name=None, beginning_buffer=2, ending_buffer=5):
203 |     # # buffers are counted in number of frame shifts
204 |     # onset = librosa.onset.onset_detect(audio, sr=sr)
205 |     # if not onset:
206 |     #     print('maybe empty file:', file_name)
207 |     # onset_sample = librosa.frames_to_samples(onset)
208 |     #
209 |     # unit = sr * frame_shift_inMS / 1000
210 |     # start_idx = onset_sample[0] - beginning_buffer * unit
211 |     # if len(onset_sample) == 1:
212 |     #     end_idx = -1
213 |     # else:
214 |     #     end_idx = onset_sample[-1] + ending_buffer * unit
215 |     #
216 |     # return audio[start_idx:end_idx]
217 |     return audio
218 | 
219 | 
220 | def preprocess_spec(dataset, f_type, rDirectory, dir_bin, q):
221 |     if dataset == 'vctk':
222 |         silence_threshold = 0.005
223 |         sample_rate = 24000
224 |     elif dataset == 'blizzard':
225 |         silence_threshold = 0.005
226 |         sample_rate = 16000
227 |     elif dataset == '10':
228 |         silence_threshold = 0.005
229 |         sample_rate = 16000
230 |     elif dataset == 'nancy':
231 |         silence_threshold = 0.005
232 |         sample_rate = 16000
233 |     elif dataset == 'etri':
234 |         silence_threshold = 0.005
235 |         sample_rate = 16000
236 | 
237 |     frame_len_inMS = 50
238 |     frame_shift_inMS = 12.5
239 |     isMono = True
240 |     type_filter = f_type
241 | 
242 |     # params for stft
243 |     n_fft = 2048
244 |     window_len = int(np.ceil(frame_len_inMS * sample_rate / 1000))
245 |     hop_length = int(np.ceil(frame_shift_inMS * sample_rate / 1000))
246 | 
247 |     # params for mel-filter
248 |     mel_dim = 80
249 |     mel_basis = librosa.filters.mel(sample_rate, n_fft, n_mels=mel_dim)
250 | 
251 |     # params for normalization
252 |     ref_level_db = 20
253 |     min_level_db = -100
254 | 
255 | 
256 |     files = []
257 |     count = 0
258 |     print('Check files..')
259 |     for root, dirnames, filenames in os.walk(rDirectory):
260 |         for filename in sorted(filenames):
261 |             if filename[-4:] == '.wav':
262 |                 path = os.path.join(root, filename)
263 |                 audio,_ = librosa.load(path, sr=sample_rate, mono=isMono)
264 |                 audio = trim_silence(audio, sample_rate, frame_shift_inMS, file_name=filename)
265 |                 length = len(audio)
266 |                 files.append((path, str(count), int(length)))
267 |                 count += 1
268 |                 # librosa.output.write_wav(path, audio, 16000)
269 | 
270 |     spec_max, spec_min = None, None
271 | 
272 |     if type_filter == 'linear':
273 |         write_path = dir_bin + 'specL.bin'
274 |     elif type_filter == 'mel':
275 |         write_path = dir_bin + 'specM.bin'
276 |     else:
277 |         write_path = None
278 | 
279 |     print('Start writing %s spectrogram binary files' % f_type)
280 |     spec_list = []
281 |     offset = 0
282 |     with open(write_path, 'wb') as w_file:
283 |         for item in files:
284 |             path = item[0]
285 |             line_idx = item[1]
286 | 
287 |             audio,_ = librosa.load(path, sr=sample_rate, mono=isMono)
288 |             audio = trim_silence(audio, sample_rate, frame_shift_inMS)
289 | 
290 |             D = librosa.stft(audio, n_fft=n_fft, win_length=window_len, window='hann', hop_length=hop_length)
291 |             spec = np.abs(D)
292 | 
293 |             if type_filter == 'mel':
294 |                 # mel-scale spectrogram generation
295 |                 spec = np.dot(mel_basis, spec)
296 |                 spec = 20 * np.log10(np.maximum(1e-5, spec))
297 |             elif type_filter == 'linear':
298 |                 # linear spectrogram generation
299 |                 spec = 20 * np.log10(np.maximum(1e-5, spec)) - ref_level_db
300 | 
301 |             # normalize
302 |             spec = np.clip(-(spec - min_level_db) / min_level_db, 0, 1)
303 |             spec = spec.T  # H x T -> T x H
304 | 
305 |             # write to file
306 |             binary_spec = pickle.dumps(spec, protocol=pickle.HIGHEST_PROTOCOL)
307 |             w_file.write(binary_spec)
308 |             spec_list.append((line_idx, len(spec), offset, len(binary_spec)))
309 |             offset += len(binary_spec)
310 | 
311 |             if not spec_max or spec_max < spec.max():
312 |                 spec_max = spec.max()
313 | 
314 |             if not spec_min or spec_min > spec.min():
315 |                 spec_min = spec.min()
316 | 
317 |     print(f_type, 'spectrogram max/min', spec_max, spec_min)
318 | 
319 |     q.put((f_type, spec_list))
320 | 
321 | 
322 | if __name__ == '__main__':
323 |     try:
324 |         parser = argparse.ArgumentParser(description='preprocessing script')
325 |         parser.add_argument('--dataset', type=str, default='10', help='vctk / blizzard / 10 / nancy / etri')
326 |         args = parser.parse_args()
327 |         dataset = args.dataset
328 |         print('Dataset to preprocess:', dataset)
329 |         write_csv = True
330 | 
331 |         if dataset == 'blizzard':
332 |             dir_text = '/data2/lyg0722/TTS_corpus/blizzard/segmented/txt/'
333 |             dir_spec = '/data2/lyg0722/TTS_corpus/blizzard/segmented/wav/'
334 |             dir_bin = '/data2/lyg0722/TTS_corpus/blizzard/segmented/bin/'
335 |         elif dataset == 'etri':
336 |             dir_text = '/data2/lyg0722/TTS_corpus/etri/txt/'
337 |             dir_spec = '/data2/lyg0722/TTS_corpus/etri/wav/'
338 |             dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/'
339 | 
340 |         q = Queue()
341 |         p_lin = Process(target=preprocess_spec, args=(dataset, 'linear', dir_spec, dir_bin, q))
342 |         p_mel = Process(target=preprocess_spec, args=(dataset, 'mel', dir_spec, dir_bin, q))
343 |         p_lin.daemon = True
344 |         p_mel.daemon = True
345 |         p_lin.start()
346 |         p_mel.start()
347 | 
348 |         lin_list = None
349 |         tmp_get = q.get()
350 | 
351 |         if tmp_get[0] == 'mel':
352 |             mel_list = tmp_get[1]
353 |         else:
354 |             lin_list = tmp_get[1]
355 |             mel_list = q.get()[1]
356 | 
357 |         # text part
358 |         get_vocab(dataset, dir_text)
359 |         txt_list = cleanse(dir_text, dir_bin, write_csv, mel_list)
360 | 
361 |         if not lin_list:
362 |             lin_list = q.get()[1]
363 | 
364 |         p_lin.join()
365 |         p_mel.join()
366 | 
367 |         # make file load list
368 |         assert len(txt_list) == len(mel_list)
369 |         line_load_list = []
370 |         for i, item in enumerate(mel_list):
371 |             assert item[0] == txt_list[i][0] and item[0] == lin_list[i][0]
372 |             line_idx = item[0]
373 |             wave_length = item[1]
374 |             text_length = txt_list[i][1]
375 |             offset_M = item[2]
376 |             offset_L = lin_list[i][2]
377 |             offset_T = txt_list[i][2]
378 |             len_M = item[3]
379 |             len_L = lin_list[i][3]
380 |             len_T = txt_list[i][3]
380 |             len_T = txt_list[i][3]
381 |             line_load_list.append((i, wave_length, text_length, offset_M, offset_L, offset_T, len_M, len_L, len_T))
382 | 
383 |         torch.save(torch.DoubleTensor(line_load_list), dir_bin + 'line_load_list.t7')
384 |     finally:
385 |         for p in multiprocessing.active_children():
386 |             p.terminate()
387 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import unicode_literals, print_function, division
3 | import argparse, multiprocessing
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torch import optim
9 | 
10 | import matplotlib
11 | matplotlib.use('Agg')
12 | import matplotlib.pyplot as plt
13 | import matplotlib.ticker as ticker
14 | 
15 | from model import Tacotron
16 | from loader import DataLoader
17 | 
18 | 
19 | def main():
20 |     parser = argparse.ArgumentParser(description='training script')
21 |     # data load
22 |     parser.add_argument('--data', type=str, default='blizzard', help='blizzard / etri')
23 |     parser.add_argument('--batch_size', type=int, default=32, help='batch size')
24 |     parser.add_argument('--text_limit', type=int, default=1000, help='maximum length of text to include in training set')
25 |     parser.add_argument('--wave_limit', type=int, default=1400, help='maximum length of spectrogram to include in training set')
26 |     parser.add_argument('--trunc_size', type=int, default=700, help='used for truncated BPTT when memory is insufficient')
27 |     parser.add_argument('--shuffle_data', type=int, default=1, help='whether to shuffle data loader')
28 |     parser.add_argument('--load_queue_size', type=int, default=8, help='maximum number of batches to load into memory')
29 |     parser.add_argument('--n_workers', type=int, default=2, help='number of workers used in data loader')
30 |     # model
31 |     parser.add_argument('--charvec_dim', type=int, default=256, help='')
32 |     parser.add_argument('--hidden_size', type=int, default=128, help='')
33 |     parser.add_argument('--dec_out_size', type=int, default=80, help='decoder output size')
34 |     parser.add_argument('--post_out_size', type=int, default=1025, help='should be n_fft / 2 + 1 (check the n_fft used for "input_specL")')
35 |     parser.add_argument('--num_filters', type=int, default=16, help='number of filters in filter bank of CBHG')
36 |     parser.add_argument('--r_factor', type=int, default=5, help='reduction factor (number of output frames per decoder step)')
37 |     parser.add_argument('--dropout', type=float, default=0.5, help='')
38 |     # optimization
39 |     parser.add_argument('--max_epochs', type=int, default=100000, help='maximum epoch to train')
40 |     parser.add_argument('--grad_clip', type=float, default=1, help='gradient clipping')
41 |     parser.add_argument('--learning_rate', type=float, default=1e-3, help='2e-3 from Ito, I used to use 5e-4')
42 |     parser.add_argument('--lr_decay_every', type=int, default=25000, help='decay learning rate every...')
43 |     parser.add_argument('--lr_decay_factor', type=float, default=0.5, help='decay learning rate by this factor')
44 |     parser.add_argument('--teacher_forcing_ratio', type=float, default=1, help='value between 0 and 1; use this for scheduled sampling')
45 |     # loading
46 |     parser.add_argument('--init_from', type=str, default='', help='load parameters from...')
47 |     parser.add_argument('--resume', type=int, default=0, help='1 to resume from saved epoch')
48 |     # misc
49 |     parser.add_argument('--exp_no', type=int, default=0, help='')
50 |     parser.add_argument('--print_every', type=int, default=-1, help='')
51 |     parser.add_argument('--plot_every', type=int, default=-1, help='')
52 |     parser.add_argument('--save_every', type=int, default=-1, help='')
53 |     parser.add_argument('--save_dir', type=str, default='checkpoint', help='')
54 |     parser.add_argument('--pinned_memory', type=int, default=1, help='1 to use pinned memory')
55 |     parser.add_argument('--gpu', type=int, nargs='+', help='index of gpu machines to run')
56 |     # debug
57 |     parser.add_argument('--debug', type=int, default=0, help='1 for debug mode')
58 |     args = parser.parse_args()
59 | 
60 |     torch.manual_seed(0)
61 | 
62 |     # set dataset option
63 |     if args.data == 'blizzard':
64 |         args.dir_bin = '/home/lyg0722/TTS_corpus/blizzard/segmented/bin/'
65 |     elif args.data == 'etri':
66 |         args.dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/'
67 |     else:
68 |         print('unknown dataset:', args.data)
69 |         return
70 | 
71 |     if args.gpu is None:
72 |         args.use_gpu = False
73 |         args.gpu = []
74 |     else:
75 |         args.use_gpu = True
76 |         torch.cuda.manual_seed(0)
77 |         torch.cuda.set_device(args.gpu[0])
78 | 
79 |     loader = DataLoader(args)
80 | 
81 |     # set misc options
82 |     args.vocab_size = loader.get_num_vocab()
83 |     if args.print_every == -1:
84 |         args.print_every = loader.iter_per_epoch
85 |     if args.plot_every == -1:
86 |         args.plot_every = args.print_every
87 |     if args.save_every == -1:
88 |         args.save_every = loader.iter_per_epoch * 10  # save every 10 epochs by default
89 | 
90 |     model = Tacotron(args)
91 |     model_optim = optim.Adam(model.parameters(), args.learning_rate)
92 |     criterion_mel = nn.L1Loss(size_average=False)
93 |     criterion_lin = nn.L1Loss(size_average=False)
94 | 
95 |     start = time.time()
96 |     plot_losses = []
97 |     print_loss_total = 0  # Reset every print_every
98 |     plot_loss_total = 0  # Reset every plot_every
99 |     start_epoch = 0
100 |     iter = 1
101 | 
102 |     if args.init_from:
103 |         checkpoint = torch.load(args.init_from, map_location=lambda storage, loc: storage)
104 |         model.load_state_dict(checkpoint['state_dict'])
105 |         if args.resume != 0:
106 |             start_epoch = checkpoint['epoch']
107 |             plot_losses = checkpoint['plot_losses']
108 |         print('loaded checkpoint %s (epoch %d)' % (args.init_from, start_epoch))
109 | 
110 |     model = model.train()
111 |     if args.use_gpu:
112 |         model = model.cuda()
113 |         criterion_mel = criterion_mel.cuda()
114 |         criterion_lin = criterion_lin.cuda()
115 | 
116 |     print('Start training... (1 epoch = %s iters)' % (loader.iter_per_epoch))
117 |     while iter < args.max_epochs * loader.iter_per_epoch + 1:
118 |         if loader.is_subbatch_end:
119 |             prev_h = (None, None, None)  # set prev_h = h_0 when new sentences are loaded
120 |         enc_input, target_mel, target_lin, wave_lengths, text_lengths = loader.next_batch('train')
121 | 
122 |         max_wave_len = max(wave_lengths)
123 | 
124 |         enc_input = Variable(enc_input, requires_grad=False)
125 |         target_mel = Variable(target_mel, requires_grad=False)
126 |         target_lin = Variable(target_lin, requires_grad=False)
127 | 
128 |         prev_h = loader.mask_prev_h(prev_h)
129 | 
130 |         model_optim.zero_grad()
131 |         pred_mel, pred_lin, prev_h = model(enc_input, target_mel[:, :-1], wave_lengths, text_lengths, prev_h)
132 | 
133 |         loss_mel = criterion_mel(pred_mel, target_mel[:, 1:])\
134 |             .div(max_wave_len * args.batch_size * args.dec_out_size)
135 |         loss_linear = criterion_lin(pred_lin, target_lin[:, 1:])\
136 |             .div(max_wave_len * args.batch_size * args.post_out_size)
137 |         loss = torch.sum(loss_mel + loss_linear)
138 | 
139 |         loss.backward()
140 |         nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)  # gradient clipping
141 |         model_optim.step()
142 | 
143 |         print_loss_total += loss.data[0]
144 |         plot_loss_total += loss.data[0]
145 | 
146 |         if iter % args.print_every == 0:
147 |             print_loss_avg = print_loss_total / args.print_every
148 |             print_loss_total = 0
149 |             progress = iter / (args.max_epochs * loader.iter_per_epoch)
150 |             print('%s (%d %d%%) %.4f' % (timeSince(start, progress), iter, progress * 100, print_loss_avg))
151 |         if iter % args.plot_every == 0:
152 |             plot_loss_avg = plot_loss_total / args.plot_every
153 |             plot_losses.append(plot_loss_avg)
154 |             plot_loss_total = 0
155 | 
156 |             save_name = '%s/%dth_exp_loss.png' % (args.save_dir, args.exp_no)
157 |             savePlot(plot_losses, save_name)
158 | 
159 | 
160 |         if iter % args.save_every == 0:
161 |             epoch = start_epoch + iter // loader.iter_per_epoch
162 |             save_name = '%s/%d_%dth.t7' % (args.save_dir, args.exp_no, epoch)
163 |             state = {
164 |                 'epoch': epoch,
165 |                 'args': args,
166 |                 'state_dict': model.state_dict(),
167 |                 'optimizer': model_optim.state_dict(),
168 |                 'plot_losses': plot_losses
169 |             }
170 |             torch.save(state, save_name)
171 |             print('model saved to', save_name)
172 |             # if is_best:  # TODO: implement saving best model.
173 |             #     shutil.copyfile(save_name, '%s/%d_best.t7' % (args.save_dir, args.exp_no))
174 | 
175 |         iter += 1
176 | 
177 | 
178 | ######################################################################
179 | # Plotting results
180 | # ----------------
181 | #
182 | # Plotting is done with matplotlib, using the array of loss values
183 | # ``plot_losses`` saved while training.
184 | #
185 | 
186 | def savePlot(points, outpath):
187 |     # plt.subplots() already creates a figure, so no separate plt.figure() is needed
188 |     fig, ax = plt.subplots()
189 |     # this locator puts ticks at regular intervals
190 |     loc = ticker.MultipleLocator(base=0.2)
191 |     ax.yaxis.set_major_locator(loc)
192 |     plt.plot(points)
193 |     plt.savefig(outpath)
194 |     plt.close('all')
195 | 
196 | ######################################################################
197 | # This is a helper function to print time elapsed and estimated time
198 | # remaining given the current time and progress %.
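# As a worked example of the formula below: with 600 s elapsed at 25%
# progress, the estimated total is 600 / 0.25 = 2400 s, so
# timeSince(start, 0.25) returns '10m 0s (- 30m 0s)'.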
199 | #
200 | 
201 | import time, math
202 | def asMinutes(s):
203 |     m = math.floor(s / 60)
204 |     s -= m * 60
205 |     return '%dm %ds' % (m, s)
206 | 
207 | 
208 | def timeSince(since, percent):
209 |     now = time.time()
210 |     s = now - since
211 |     es = s / percent
212 |     rs = es - s
213 |     return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
214 | 
215 | 
216 | if __name__ == '__main__':
217 |     try:
218 |         main()
219 |     finally:
220 |         for p in multiprocessing.active_children():
221 |             # p.join()
222 |             p.terminate()
223 | 
--------------------------------------------------------------------------------
/trimmer.py:
--------------------------------------------------------------------------------
1 | import webrtcvad, os, wave, contextlib, collections, argparse
2 | 
3 | def read_wave(path):
4 |     with contextlib.closing(wave.open(path, 'rb')) as wf:
5 |         num_channels = wf.getnchannels()
6 |         assert num_channels == 1
7 |         sample_width = wf.getsampwidth()
8 |         assert sample_width == 2
9 |         sample_rate = wf.getframerate()
10 |         assert sample_rate in (8000, 16000, 32000)
11 |         pcm_data = wf.readframes(wf.getnframes())
12 |         return pcm_data, sample_rate
13 | 
14 | 
15 | def write_wave(path, audio, sample_rate):
16 |     with contextlib.closing(wave.open(path, 'wb')) as wf:
17 |         wf.setnchannels(1)
18 |         wf.setsampwidth(2)
19 |         wf.setframerate(sample_rate)
20 |         wf.writeframes(audio)
21 | 
22 | 
23 | class Frame(object):
24 |     def __init__(self, bytes, timestamp, duration):
25 |         self.bytes = bytes
26 |         self.timestamp = timestamp
27 |         self.duration = duration
28 | 
29 | 
30 | def frame_generator(frame_duration_ms, audio, sample_rate):
31 |     n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)  # bytes per frame (2 bytes per 16-bit sample)
32 |     offset = 0
33 |     timestamp = 0.0
34 |     duration = (float(n) / sample_rate) / 2.0  # n/2 samples at sample_rate -> seconds
35 |     output = []
36 |     while offset + n < len(audio):
37 |         output.append(Frame(audio[offset:offset + n], timestamp, duration))
38 |         timestamp += duration
39 |         offset += n
40 |     return output
41 | 
42 | 
43 | def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms,
44 |                   aggressiveness, buffer_ratio, frames, name):
45 |     num_padding_frames = int(padding_duration_ms / frame_duration_ms)
46 |     ring_buffer = collections.deque(maxlen=num_padding_frames)
47 |     triggered = False
48 |     voiced_frames = []
49 |     vad = webrtcvad.Vad(aggressiveness)
50 |     count_until_triggered = 0
51 |     count_unvoiced_tail = 0
52 |     for frame in frames:
53 |         if not triggered:
54 |             ring_buffer.append(frame)
55 |             num_voiced = 0
56 |             for f in ring_buffer:
57 |                 if vad.is_speech(f.bytes, sample_rate):
58 |                     num_voiced += 1
59 |             if num_voiced > buffer_ratio * ring_buffer.maxlen:  # trigger once the buffer is mostly voiced
60 |                 triggered = True
61 |                 voiced_frames.extend(ring_buffer)
62 |                 ring_buffer.clear()
63 |             count_until_triggered += 1
64 |         else:
65 |             voiced_frames.append(frame)
66 |             if vad.is_speech(frame.bytes, sample_rate):
67 |                 count_unvoiced_tail = 0
68 |             else:
69 |                 count_unvoiced_tail -= 1
70 |     if voiced_frames:
71 |         if count_unvoiced_tail != 0:
72 |             voiced_frames = voiced_frames[:count_unvoiced_tail]  # drop the trailing unvoiced frames
73 |         vad_list = []
74 |         vad = webrtcvad.Vad(aggressiveness)
75 |         for f in voiced_frames[:ring_buffer.maxlen]:
76 |             vad_list.append(vad.is_speech(f.bytes, sample_rate))
77 |     else:
78 |         print('Maybe unvoiced file.', name)
79 |         vad_list = []
80 |         vad = webrtcvad.Vad(aggressiveness)
81 |         for f in frames:
82 |             vad_list.append(vad.is_speech(f.bytes, sample_rate))
83 |         voiced_frames = frames
84 | 
85 |     first_voice = vad_list.index(True)  # drop any leading unvoiced frames
86 |     voiced_frames = voiced_frames[first_voice:]
87 |     return b''.join([f.bytes for f in voiced_frames])
88 | 
89 | 
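# A minimal usage sketch for the helpers above, on a hypothetical 16 kHz mono
# file (the __main__ block below is the real batch entry point):
#
#   pcm, sr = read_wave('utt0001.wav')
#   frames = frame_generator(30, pcm, sr)  # 30 ms frames
#   trimmed = vad_collector(sr, 30, 300, 2, 0.3, frames, 'utt0001.wav')
#   write_wave('utt0001_trimmed.wav', trimmed, sr)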
90 | if __name__ == '__main__':
91 |     """
92 |     Code from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
93 |     """
94 |     parser = argparse.ArgumentParser(description='VAD-based silence trimming script')
95 |     parser.add_argument('--aggressiveness', type=int, default=2, help='integer from 0 to 3; aggressiveness about filtering out non-speech, 3 is the most aggressive')
96 |     parser.add_argument('--buffer_ratio', type=float, default=0.3, help='ratio of voiced frames required to trigger the buffer')
97 |     args = parser.parse_args()
98 | 
99 |     aggressiveness = args.aggressiveness
100 |     buffer_ratio = args.buffer_ratio
101 | 
102 |     rDirectory = '/data2/lyg0722/TTS_corpus/etri/wav_old/'
103 |     wDirectory = '/data2/lyg0722/TTS_corpus/etri/wav/'
104 | 
105 |     for root, dirnames, filenames in os.walk(rDirectory):
106 |         for filename in sorted(filenames):
107 |             if filename[-4:] == '.wav':
108 |                 rf = os.path.join(root, filename)
109 |                 audio, sample_rate = read_wave(rf)
110 |                 frames = frame_generator(30, audio, sample_rate)
111 |                 segment = vad_collector(sample_rate, 30, 300, aggressiveness, buffer_ratio, frames, filename)
112 |                 wPath = str(wDirectory + filename)
113 |                 write_wave(wPath, segment, sample_rate)
114 | 
115 | # TODO: use multiprocess to speed up.
--------------------------------------------------------------------------------
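A note on the binary layout (not one of the repository files above): preprocess.py pickles each spectrogram into specM.bin / specL.bin and records, per utterance, the tuple (i, wave_length, text_length, offset_M, offset_L, offset_T, len_M, len_L, len_T) in line_load_list.t7. A single mel spectrogram can therefore be recovered with a seek and a pickle.loads. The sketch below assumes the etri paths used above; loader.py remains the repository's actual reader.

import pickle
import torch

dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/'        # example path from preprocess.py
load_list = torch.load(dir_bin + 'line_load_list.t7')  # DoubleTensor, one row per utterance

row = load_list[0]
offset_M, len_M = int(row[3]), int(row[6])             # byte offset / length of the first mel spec

with open(dir_bin + 'specM.bin', 'rb') as f:
    f.seek(offset_M)
    mel = pickle.loads(f.read(len_M))                  # numpy array of shape (T, 80)
print(mel.shape)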