├── .gitignore ├── LICENSE ├── README.md ├── fine-tune ├── distributed.py ├── hparams.py ├── inference.py ├── inference_embedding.py ├── inference_utils.py ├── logger.py ├── model │ ├── __init__.py │ ├── basic_layers.py │ ├── beam.py │ ├── decoder.py │ ├── layers.py │ ├── loss.py │ ├── model.py │ ├── penalties.py │ └── utils.py ├── multiproc.py ├── plotting_utils.py ├── reader │ ├── __init__.py │ ├── reader.py │ └── symbols.py ├── run.sh ├── train.py └── zero_embeddings.npy ├── pre-train ├── distributed.py ├── hparams.py ├── inference.py ├── logger.py ├── model │ ├── __init__.py │ ├── basic_layers.py │ ├── beam.py │ ├── decoder.py │ ├── layers.py │ ├── loss.py │ ├── model.py │ ├── penalties.py │ └── utils.py ├── multiproc.py ├── plotting_utils.py ├── reader │ ├── __init__.py │ ├── extract_features.py │ ├── reader.py │ └── symbols.py ├── run.sh └── train.py ├── requirements.txt └── struct.PNG /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoint_* 2 | events.* 3 | *.pdf 4 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2019] [Jing-Xuan Zhang] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Non-parallel Seq2seq Voice Conversion 2 | 3 | Implementation code of [Non-Parallel Sequence-to-Sequence Voice Conversion with Disentangled Linguistic and Speaker Representations](https://arxiv.org/abs/1906.10508). 4 | 5 | For audio samples, please visit our [demo page](https://jxzhanggg.github.io/nonparaSeq2seqVC/). 6 | 7 | ![The structure overview of the model](struct.PNG) 8 | 9 | ## Dependencies 10 | 11 | * Python 3.6 12 | * PyTorch 1.0.1 13 | * CUDA 10.0 14 | 15 | ## Data 16 | 17 | It is recommended you download the [VCTK](http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html) and [CMU-ARCTIC](http://www.speech.cs.cmu.edu/cmu_arctic/packed/) datasets. 18 | 19 | ## Usage 20 | 21 | ### Installation 22 | 23 | Install Python dependencies. 
24 | 25 | ```bash 26 | $ pip install -r requirements.txt 27 | ``` 28 | 29 | ### Feature Extraction 30 | 31 | #### Extract Mel-Spectrograms, Spectrograms and Phonemes 32 | 33 | You can use [`extract_features.py`](https://github.com/jxzhanggg/nonparaSeq2seqVC_code/blob/master/pre-train/reader/extract_features.py) 34 | 35 | 36 | ### Customize data reader 37 | 38 | Write a snippet of code to walk through the dataset for generating list file for train, valid and test set. 39 | 40 | Then you will need to modify the data reader to read your training data. The following are scripts you will need to modify. 41 | 42 | For pre-training: 43 | 44 | - [`reader.py`](https://github.com/jxzhanggg/nonparaSeq2seqVC_code/blob/master/pre-train/reader/reader.py) 45 | - [`symbols.py`](https://github.com/jxzhanggg/nonparaSeq2seqVC_code/blob/master/pre-train/reader/symbols.py) 46 | 47 | For fine-tuning: 48 | 49 | - [`reader.py`](https://github.com/jxzhanggg/nonparaSeq2seqVC_code/blob/master/fine-tune/reader/reader.py) 50 | - [`symbols.py`](https://github.com/jxzhanggg/nonparaSeq2seqVC_code/blob/master/fine-tune/reader/symbols.py) 51 | 52 | 53 | 54 | ### Pre-train the model 55 | 56 | Add correct paths to your local data, and run the bash script: 57 | 58 | ```bash 59 | $ cd pre-train 60 | $ bash run.sh 61 | ``` 62 | 63 | Run the inference code to generate audio samples on multi-speaker dataset. During inference, our model can be run on either TTS (using text inputs) or VC (using Mel-spectrogram inputs) mode. 64 | 65 | ```bash 66 | $ python inference.py 67 | ``` 68 | 69 | ### Fine-tune the model 70 | 71 | Fine-tune the model and generate audio samples on conversion pair. During inference, our model can be run on either TTS (using text inputs) or VC (using Mel-spectrogram inputs) mode. 72 | 73 | ```bash 74 | $ cd fine-tune 75 | $ bash run.sh 76 | ``` 77 | 78 | ## Training Time 79 | 80 | On a single NVIDIA 1080 Ti GPU, with a batch size of 32, pre-training on VCTK takes approximately 64 hours of wall-clock time. Fine-tuning on two speakers (500 utterances each speaker) with a batch size of 8 takes approximately 6 hours of wall-clock time. 81 | 82 | ## Citation 83 | 84 | If you use this code, please cite: 85 | ```bibtex 86 | @article{zhangnonpara2020, 87 | author={Jing-Xuan {Zhang} and Zhen-Hua {Ling} and Li-Rong {Dai}}, 88 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 89 | title={Non-Parallel Sequence-to-Sequence Voice Conversion with Disentangled Linguistic and Speaker Representations}, 90 | year={2020}, 91 | volume={28}, 92 | number={1}, 93 | pages={540-552}} 94 | 95 | ``` 96 | 97 | ## Acknowledgements 98 | 99 | Part of code was adapted from the following project: 100 | * https://github.com/NVIDIA/tacotron2/ 101 | * https://github.com/r9y9/deepvoice3_pytorch 102 | -------------------------------------------------------------------------------- /fine-tune/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.modules import Module 4 | from torch.autograd import Variable 5 | 6 | def _flatten_dense_tensors(tensors): 7 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 8 | same dense type. 9 | Since inputs are dense, the resulting tensor will be a concatenated 1D 10 | buffer. Element-wise operation on this buffer will be equivalent to 11 | operating individually. 12 | Arguments: 13 | tensors (Iterable[Tensor]): dense tensors to flatten. 
14 | Returns: 15 | A contiguous 1D buffer containing input tensors. 16 | """ 17 | if len(tensors) == 1: 18 | return tensors[0].contiguous().view(-1) 19 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 20 | return flat 21 | 22 | def _unflatten_dense_tensors(flat, tensors): 23 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 24 | same dense type, and that flat is given by _flatten_dense_tensors. 25 | Arguments: 26 | flat (Tensor): flattened dense tensors to unflatten. 27 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 28 | unflatten flat. 29 | Returns: 30 | Unflattened dense tensors with sizes same as tensors and values from 31 | flat. 32 | """ 33 | outputs = [] 34 | offset = 0 35 | for tensor in tensors: 36 | numel = tensor.numel() 37 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 38 | offset += numel 39 | return tuple(outputs) 40 | 41 | 42 | ''' 43 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py 44 | launcher included with this example. It assumes that your run is using multiprocess with 1 45 | GPU/process, that the model is on the correct device, and that torch.set_device has been 46 | used to set the device. 47 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 48 | and will be allreduced at the finish of the backward pass. 49 | ''' 50 | class DistributedDataParallel(Module): 51 | 52 | def __init__(self, module): 53 | super(DistributedDataParallel, self).__init__() 54 | #fallback for PyTorch 0.3 55 | if not hasattr(dist, '_backend'): 56 | self.warn_on_half = True 57 | else: 58 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 59 | 60 | self.module = module 61 | 62 | for p in list(self.module.state_dict().values()): 63 | if not torch.is_tensor(p): 64 | continue 65 | dist.broadcast(p, 0) 66 | 67 | def allreduce_params(): 68 | if(self.needs_reduction): 69 | self.needs_reduction = False 70 | buckets = {} 71 | for param in self.module.parameters(): 72 | if param.requires_grad and param.grad is not None: 73 | tp = type(param.data) 74 | if tp not in buckets: 75 | buckets[tp] = [] 76 | buckets[tp].append(param) 77 | if self.warn_on_half: 78 | if torch.cuda.HalfTensor in buckets: 79 | print(("WARNING: gloo dist backend for half parameters may be extremely slow." + 80 | " It is recommended to use the NCCL backend in this case. 
This currently requires" + 81 | "PyTorch built from top of tree master.")) 82 | self.warn_on_half = False 83 | 84 | for tp in buckets: 85 | bucket = buckets[tp] 86 | grads = [param.grad.data for param in bucket] 87 | coalesced = _flatten_dense_tensors(grads) 88 | dist.all_reduce(coalesced) 89 | coalesced /= dist.get_world_size() 90 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 91 | buf.copy_(synced) 92 | 93 | for param in list(self.module.parameters()): 94 | def allreduce_hook(*unused): 95 | param._execution_engine.queue_callback(allreduce_params) 96 | if param.requires_grad: 97 | param.register_hook(allreduce_hook) 98 | 99 | def forward(self, *inputs, **kwargs): 100 | self.needs_reduction = True 101 | return self.module(*inputs, **kwargs) 102 | 103 | ''' 104 | def _sync_buffers(self): 105 | buffers = list(self.module._all_buffers()) 106 | if len(buffers) > 0: 107 | # cross-node buffer sync 108 | flat_buffers = _flatten_dense_tensors(buffers) 109 | dist.broadcast(flat_buffers, 0) 110 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 111 | buf.copy_(synced) 112 | def train(self, mode=True): 113 | # Clear NCCL communicator and CUDA event cache of the default group ID, 114 | # These cache will be recreated at the later call. This is currently a 115 | # work-around for a potential NCCL deadlock. 116 | if dist._backend == dist.dist_backend.NCCL: 117 | dist._clear_group_cache() 118 | super(DistributedDataParallel, self).train(mode) 119 | self.module.train(mode) 120 | ''' 121 | ''' 122 | Modifies existing model to do gradient allreduce, but doesn't change class 123 | so you don't need "module" 124 | ''' 125 | def apply_gradient_allreduce(module): 126 | if not hasattr(dist, '_backend'): 127 | module.warn_on_half = True 128 | else: 129 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 130 | 131 | for p in list(module.state_dict().values()): 132 | if not torch.is_tensor(p): 133 | continue 134 | dist.broadcast(p, 0) 135 | 136 | def allreduce_params(): 137 | if(module.needs_reduction): 138 | module.needs_reduction = False 139 | buckets = {} 140 | for param in module.parameters(): 141 | if param.requires_grad and param.grad is not None: 142 | tp = type(param.data) 143 | if tp not in buckets: 144 | buckets[tp] = [] 145 | buckets[tp].append(param) 146 | if module.warn_on_half: 147 | if torch.cuda.HalfTensor in buckets: 148 | print(("WARNING: gloo dist backend for half parameters may be extremely slow." + 149 | " It is recommended to use the NCCL backend in this case. 
This currently requires" + 150 | "PyTorch built from top of tree master.")) 151 | module.warn_on_half = False 152 | 153 | for tp in buckets: 154 | bucket = buckets[tp] 155 | grads = [param.grad.data for param in bucket] 156 | coalesced = _flatten_dense_tensors(grads) 157 | dist.all_reduce(coalesced) 158 | coalesced /= dist.get_world_size() 159 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 160 | buf.copy_(synced) 161 | 162 | for param in list(module.parameters()): 163 | def allreduce_hook(*unused): 164 | Variable._execution_engine.queue_callback(allreduce_params) 165 | if param.requires_grad: 166 | param.register_hook(allreduce_hook) 167 | 168 | def set_needs_reduction(self, input, output): 169 | self.needs_reduction = True 170 | 171 | module.register_forward_hook(set_needs_reduction) 172 | return module -------------------------------------------------------------------------------- /fine-tune/hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | #from text import symbols 3 | 4 | def create_hparams(hparams_string=None, verbose=False): 5 | """Create model hyperparameters. Parse nondefault from given string.""" 6 | 7 | hparams = tf.contrib.training.HParams( 8 | ################################ 9 | # Experiment Parameters # 10 | ################################ 11 | epochs=70, 12 | iters_per_checkpoint=100, 13 | seed=1234, 14 | dynamic_loss_scaling=True, 15 | fp16_run=False, 16 | distributed_run=False, 17 | dist_backend="nccl", 18 | dist_url="tcp://localhost:54321", 19 | cudnn_enabled=True, 20 | cudnn_benchmark=False, 21 | 22 | ################################ 23 | # Data Parameters # 24 | ################################ 25 | training_list='/home/jxzhang/Documents/DataSets/cmu_us_slt_arctic-0.95-release/list/train_non-parallel_slt_rms.list', 26 | validation_list='/home/jxzhang/Documents/DataSets/cmu_us_slt_arctic-0.95-release/list/eval_slt_rms.list', 27 | mel_mean_std='/home/jxzhang/Documents/DataSets/VCTK/mel_mean_std.npy', 28 | 29 | speaker_A='slt', 30 | speaker_B='rms', 31 | a_embedding_path='zero_embeddings.npy', 32 | b_embedding_path='zero_embeddings.npy', 33 | ################################ 34 | # Data Parameters # 35 | ################################ 36 | n_mel_channels=80, 37 | n_spc_channels=1025, 38 | n_symbols=41, # 39 | pretrain_n_speakers=99, # 40 | n_speakers=2, 41 | predict_spectrogram=False, 42 | 43 | ################################ 44 | # Model Parameters # 45 | ################################ 46 | 47 | symbols_embedding_dim=512, 48 | 49 | # Text Encoder parameters 50 | encoder_kernel_size=5, 51 | encoder_n_convolutions=3, 52 | encoder_embedding_dim=512, 53 | text_encoder_dropout=0.5, 54 | 55 | # Audio Encoder parameters 56 | spemb_input=False, 57 | n_frames_per_step_encoder=2, 58 | audio_encoder_hidden_dim=512, 59 | AE_attention_dim=128, 60 | AE_attention_location_n_filters=32, 61 | AE_attention_location_kernel_size=51, 62 | beam_width=10, 63 | 64 | # hidden activation 65 | # relu linear tanh 66 | hidden_activation='tanh', 67 | 68 | #Speaker Encoder parameters 69 | speaker_encoder_hidden_dim=256, 70 | speaker_encoder_dropout=0.2, 71 | speaker_embedding_dim=128, 72 | 73 | #Text Classifier parameters 74 | #text_classifier_hidden_dim=256, 75 | 76 | #Speaker Classifier parameters 77 | SC_hidden_dim=512, 78 | SC_n_convolutions=3, 79 | SC_kernel_size=5, 80 | 81 | # Decoder parameters 82 | feed_back_last=True, 83 | n_frames_per_step_decoder=2, 84 | decoder_rnn_dim=512, 85 | 
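# note: the decoder pre-net is two fully-connected layers of 256 units each
# (prenet_dim below); its output, together with the attention context, feeds
# the attention LSTM at every decoder step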
prenet_dim=[256,256], 86 | max_decoder_steps=1000, 87 | stop_threshold=0.5, 88 | 89 | # Attention parameters 90 | attention_rnn_dim=512, 91 | attention_dim=128, 92 | 93 | # Location Layer parameters 94 | attention_location_n_filters=32, 95 | attention_location_kernel_size=17, 96 | 97 | # PostNet parameters 98 | postnet_n_convolutions=5, 99 | postnet_dim=512, 100 | postnet_kernel_size=5, 101 | postnet_dropout=0.5, 102 | 103 | ################################ 104 | # Optimization Hyperparameters # 105 | ################################ 106 | use_saved_learning_rate=False, 107 | learning_rate=1e-3, 108 | weight_decay=1e-6, 109 | grad_clip_thresh=5.0, 110 | batch_size=8, 111 | warmup=7, 112 | decay_rate=0.5, 113 | decay_every=7, 114 | 115 | contrastive_loss_w=30.0, 116 | speaker_encoder_loss_w=0., 117 | text_classifier_loss_w=1.0, 118 | speaker_adversial_loss_w=0.2, 119 | speaker_classifier_loss_w=1.0, 120 | 121 | ) 122 | 123 | if hparams_string: 124 | tf.logging.info('Parsing command line hparams: %s', hparams_string) 125 | hparams.parse(hparams_string) 126 | 127 | if verbose: 128 | tf.logging.info('Final parsed hparams: %s', list(hparams.values())) 129 | 130 | return hparams 131 | -------------------------------------------------------------------------------- /fine-tune/inference.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | 5 | 6 | import os 7 | import librosa 8 | import numpy as np 9 | import torch 10 | import argparse 11 | from torch.utils.data import DataLoader 12 | 13 | from reader import TextMelIDLoader, TextMelIDCollate, id2ph, id2sp 14 | from hparams import create_hparams 15 | from model import Parrot, lcm 16 | from train import load_model 17 | from inference_utils import plot_data, levenshteinDistance, recover_wav 18 | import scipy.io.wavfile 19 | 20 | AA_tts, BB_tts, AB_vc, BA_vc = False, False, True, True 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('-c', '--checkpoint_path', type=str, 24 | help='directory to save checkpoints') 25 | parser.add_argument('--num', type=int, default=10, 26 | required=False, help='num of samples to be generated') 27 | parser.add_argument('--hparams', type=str, 28 | required=False, help='comma separated name=value pairs') 29 | args = parser.parse_args() 30 | 31 | 32 | hparams = create_hparams(args.hparams) 33 | 34 | test_list = hparams.validation_list 35 | checkpoint_path=args.checkpoint_path 36 | gen_num = args.num 37 | ISMEL=(not hparams.predict_spectrogram) 38 | 39 | 40 | model = load_model(hparams) 41 | 42 | 43 | model.load_state_dict(torch.load(checkpoint_path)['state_dict'], strict=False) 44 | _ = model.eval() 45 | 46 | 47 | 48 | train_set_A = TextMelIDLoader(hparams.training_list, hparams.mel_mean_std, 49 | hparams.speaker_A,hparams.speaker_B, 50 | shuffle=False,pids=[hparams.speaker_A]) 51 | 52 | train_set_B = TextMelIDLoader(hparams.training_list, hparams.mel_mean_std, 53 | hparams.speaker_A,hparams.speaker_B, 54 | shuffle=False,pids=[hparams.speaker_B]) 55 | 56 | test_set_A = TextMelIDLoader(test_list, hparams.mel_mean_std, 57 | hparams.speaker_A,hparams.speaker_B, 58 | shuffle=False,pids=[hparams.speaker_A]) 59 | 60 | test_set_B = TextMelIDLoader(test_list, hparams.mel_mean_std, 61 | hparams.speaker_A,hparams.speaker_B, 62 | shuffle=False,pids=[hparams.speaker_B]) 63 | 64 | sample_list_A = test_set_A.file_path_list 65 | sample_list_B = test_set_B.file_path_list 66 | 67 | collate_fn = 
TextMelIDCollate(lcm(hparams.n_frames_per_step_encoder, 68 | hparams.n_frames_per_step_decoder)) 69 | 70 | test_loader_A = DataLoader(test_set_A, num_workers=1, shuffle=False, 71 | sampler=None, 72 | batch_size=1, pin_memory=False, 73 | drop_last=False, collate_fn=collate_fn) 74 | 75 | test_loader_B = DataLoader(test_set_B, num_workers=1, shuffle=False, 76 | sampler=None, 77 | batch_size=1, pin_memory=False, 78 | drop_last=False, collate_fn=collate_fn) 79 | 80 | 81 | id2sp[0] = hparams.speaker_A 82 | id2sp[1] = hparams.speaker_B 83 | 84 | _, mel, __, speaker_id = train_set_A[0] 85 | reference_mel_A = speaker_id.cuda() 86 | ref_sp_A = id2sp[speaker_id.item()] 87 | 88 | _, mel, __, speaker_id = train_set_B[0] 89 | reference_mel_B = speaker_id.cuda() 90 | ref_sp_B = id2sp[speaker_id.item()] 91 | 92 | 93 | 94 | def get_path(input_text, A, B): 95 | task = 'tts' if input_text else 'vc' 96 | 97 | path_save = os.path.join(checkpoint_path.replace('checkpoint', 'test'), task) 98 | 99 | path_save += '_%s_to_%s'%(A, B) 100 | 101 | if not os.path.exists(os.path.join(path_save,'wav_mel')): 102 | os.makedirs(os.path.join(path_save,'wav_mel')) 103 | 104 | if not os.path.exists(os.path.join(path_save,'mel')): 105 | os.makedirs(os.path.join(path_save,'mel')) 106 | 107 | if not os.path.exists(os.path.join(path_save,'hid')): 108 | os.makedirs(os.path.join(path_save,'hid')) 109 | 110 | if not os.path.exists(os.path.join(path_save,'ali')): 111 | os.makedirs(os.path.join(path_save,'ali')) 112 | 113 | print(path_save) 114 | return path_save 115 | 116 | 117 | def generate(loader, reference_mel, beam_width, path_save, ref_sp, 118 | sample_list, num=10, input_text=False): 119 | 120 | with torch.no_grad(): 121 | errs = 0 122 | totalphs = 0 123 | 124 | for i, batch in enumerate(loader): 125 | if i == num: 126 | break 127 | 128 | sample_id = sample_list[i].split('/')[-1][9:17+4] 129 | print(('%d index %s, decoding ...'%(i,sample_id))) 130 | 131 | x, y = model.parse_batch(batch) 132 | predicted_mel, post_output, predicted_stop, alignments, \ 133 | text_hidden, audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments, \ 134 | speaker_id = model.inference(x, input_text, reference_mel, beam_width) 135 | 136 | post_output = post_output.data.cpu().numpy()[0] 137 | alignments = alignments.data.cpu().numpy()[0].T 138 | audio_seq2seq_alignments = audio_seq2seq_alignments.data.cpu().numpy()[0].T 139 | 140 | text_hidden = text_hidden.data.cpu().numpy()[0].T #-> [hidden_dim, max_text_len] 141 | audio_seq2seq_hidden = audio_seq2seq_hidden.data.cpu().numpy()[0].T 142 | audio_seq2seq_phids = audio_seq2seq_phids.data.cpu().numpy()[0] # [T + 1] 143 | speaker_id = speaker_id.data.cpu().numpy()[0] # scalar 144 | 145 | task = 'TTS' if input_text else 'VC' 146 | 147 | recover_wav(post_output, 148 | os.path.join(path_save, 'wav_mel/Wav_%s_ref_%s_%s.wav'%(sample_id, ref_sp, task)), 149 | hparams.mel_mean_std, 150 | ismel=ISMEL) 151 | 152 | post_output_path = os.path.join(path_save, 'mel/Mel_%s_ref_%s_%s.npy'%(sample_id, ref_sp, task)) 153 | np.save(post_output_path, post_output) 154 | 155 | plot_data([alignments, audio_seq2seq_alignments], 156 | os.path.join(path_save, 'ali/Ali_%s_ref_%s_%s.pdf'%(sample_id, ref_sp, task))) 157 | 158 | plot_data([np.hstack([text_hidden, audio_seq2seq_hidden])], 159 | os.path.join(path_save, 'hid/Hid_%s_ref_%s_%s.pdf'%(sample_id, ref_sp, task))) 160 | 161 | audio_seq2seq_phids = [id2ph[id] for id in audio_seq2seq_phids[:-1]] 162 | target_text = y[0].data.cpu().numpy()[0] 163 | target_text = 
[id2ph[id] for id in target_text[:]] 164 | 165 | if not input_text: 166 | #print 'Sounds like %s, Decoded text is '%(id2sp[speaker_id]) 167 | print(audio_seq2seq_phids) 168 | print(target_text) 169 | 170 | err = levenshteinDistance(audio_seq2seq_phids, target_text) 171 | print(err, len(target_text)) 172 | 173 | errs += err 174 | totalphs += len(target_text) 175 | 176 | #print float(errs)/float(totalphs) 177 | return float(errs)/float(totalphs) 178 | 179 | 180 | ####### TTS A - A ############ 181 | 182 | if AA_tts: 183 | path_save = get_path(True, ref_sp_A, ref_sp_A) 184 | generate(test_loader_A, reference_mel_A, hparams.beam_width, 185 | path_save, ref_sp_A, sample_list_A, num=gen_num, input_text=True) 186 | 187 | ####### TTS B - B ############ 188 | if BB_tts: 189 | path_save = get_path(True, ref_sp_B, ref_sp_B) 190 | generate(test_loader_B, reference_mel_B, hparams.beam_width, 191 | path_save, ref_sp_B, sample_list_B, num=gen_num, input_text=True) 192 | 193 | ####### VC A - B ############# 194 | if AB_vc: 195 | path_save = get_path(False, ref_sp_A, ref_sp_B) 196 | per_AB = generate(test_loader_A, reference_mel_B, hparams.beam_width, 197 | path_save, ref_sp_B, sample_list_A, num=gen_num, input_text=False) 198 | print(('PER %s-to-%s is %.4f'%(ref_sp_A, ref_sp_B, per_AB))) 199 | 200 | ####### VC B - A ############# 201 | if BA_vc: 202 | path_save = get_path(False, ref_sp_B, ref_sp_A) 203 | per_BA = generate(test_loader_B, reference_mel_A, hparams.beam_width, 204 | path_save, ref_sp_A, sample_list_B, num=gen_num, input_text=False) 205 | print(('PER %s-to-%s is %.4f'%(ref_sp_B, ref_sp_A, per_BA))) 206 | 207 | -------------------------------------------------------------------------------- /fine-tune/inference_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import argparse 5 | 6 | from hparams import create_hparams 7 | from model import lcm 8 | from train import load_model 9 | from torch.utils.data import DataLoader 10 | from reader import TextMelIDLoader, TextMelIDCollate, id2sp 11 | from inference_utils import plot_data 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-c', '--checkpoint_path', type=str, 15 | help='directory to save checkpoints') 16 | parser.add_argument('--hparams', type=str, 17 | required=False, help='comma separated name=value pairs') 18 | args = parser.parse_args() 19 | 20 | checkpoint_path=args.checkpoint_path 21 | 22 | hparams = create_hparams(args.hparams) 23 | 24 | model = load_model(hparams) 25 | model.load_state_dict(torch.load(checkpoint_path)['state_dict'], strict=False) 26 | _ = model.eval() 27 | 28 | 29 | def gen_embedding(speaker): 30 | 31 | training_list = hparams.training_list 32 | 33 | train_set_A = TextMelIDLoader(training_list, hparams.mel_mean_std, hparams.speaker_A, 34 | hparams.speaker_B, 35 | shuffle=False,pids=[speaker]) 36 | 37 | collate_fn = TextMelIDCollate(lcm(hparams.n_frames_per_step_encoder, 38 | hparams.n_frames_per_step_decoder)) 39 | 40 | train_loader_A = DataLoader(train_set_A, num_workers=1, shuffle=False, 41 | sampler=None, 42 | batch_size=1, pin_memory=False, 43 | drop_last=True, collate_fn=collate_fn) 44 | 45 | with torch.no_grad(): 46 | 47 | speaker_embeddings = [] 48 | 49 | for i,batch in enumerate(train_loader_A): 50 | #print i 51 | x, y = model.parse_batch(batch) 52 | text_input_padded, mel_padded, text_lengths, mel_lengths, speaker_id = x 53 | speaker_id, speaker_embedding = 
model.speaker_encoder.inference(mel_padded) 54 | 55 | speaker_embedding = speaker_embedding.data.cpu().numpy() 56 | speaker_embeddings.append(speaker_embedding) 57 | 58 | speaker_embeddings = np.vstack(speaker_embeddings) 59 | 60 | print(speaker_embeddings.shape) 61 | if not os.path.exists('outdir/embeddings'): 62 | os.makedirs('outdir/embeddings') 63 | 64 | np.save('outdir/embeddings/%s.npy'%speaker, speaker_embeddings) 65 | plot_data([speaker_embeddings], 66 | 'outdir/embeddings/%s.pdf'%speaker) 67 | 68 | 69 | print('Generating embedding of %s ...'%hparams.speaker_A) 70 | gen_embedding(hparams.speaker_A) 71 | 72 | print('Generating embedding of %s ...'%hparams.speaker_B) 73 | gen_embedding(hparams.speaker_B) 74 | -------------------------------------------------------------------------------- /fine-tune/inference_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | 5 | import numpy as np 6 | import librosa 7 | import scipy.io.wavfile 8 | 9 | def plot_data(data, fn, figsize=(12, 4)): 10 | fig, axes = plt.subplots(1, len(data), figsize=figsize) 11 | for i in range(len(data)): 12 | if len(data) == 1: 13 | ax = axes 14 | else: 15 | ax = axes[i] 16 | g = ax.imshow(data[i], aspect='auto', origin='bottom', 17 | interpolation='none') 18 | plt.colorbar(g, ax=ax) 19 | plt.savefig(fn) 20 | 21 | 22 | 23 | def levenshteinDistance(s1, s2): 24 | if len(s1) > len(s2): 25 | s1, s2 = s2, s1 26 | 27 | distances = list(range(len(s1) + 1)) 28 | for i2, c2 in enumerate(s2): 29 | distances_ = [i2+1] 30 | for i1, c1 in enumerate(s1): 31 | if c1 == c2: 32 | distances_.append(distances[i1]) 33 | else: 34 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 35 | distances = distances_ 36 | return distances[-1] 37 | 38 | 39 | def recover_wav(mel, wav_path, mel_mean_std, ismel=False, 40 | n_fft = 2048,win_length=800, hop_length=200): 41 | if ismel: 42 | mean, std = np.load(mel_mean_std) 43 | else: 44 | mean, std = np.load(mel_mean_std.replace('mel','spec')) 45 | 46 | mean = mean[:,None] 47 | std = std[:,None] 48 | mel = 1.2 * mel * std + mean 49 | mel = np.exp(mel) 50 | 51 | if ismel: 52 | filters = librosa.filters.mel(sr=16000, n_fft=2048, n_mels=80) 53 | inv_filters = np.linalg.pinv(filters) 54 | spec = np.dot(inv_filters, mel) 55 | else: 56 | spec = mel 57 | 58 | def _griffin_lim(stftm_matrix, shape, max_iter=50): 59 | y = np.random.random(shape) 60 | for i in range(max_iter): 61 | stft_matrix = librosa.core.stft(y, n_fft=n_fft, win_length=win_length, hop_length=hop_length) 62 | stft_matrix = stftm_matrix * stft_matrix / np.abs(stft_matrix) 63 | y = librosa.core.istft(stft_matrix, win_length=win_length, hop_length=hop_length) 64 | return y 65 | 66 | shape = spec.shape[1] * hop_length - hop_length + 1 67 | 68 | y = _griffin_lim(spec, shape) 69 | scipy.io.wavfile.write(wav_path, 16000, y) 70 | return y -------------------------------------------------------------------------------- /fine-tune/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch.nn.functional as F 4 | from tensorboardX import SummaryWriter 5 | from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy, plot_alignment 6 | from plotting_utils import plot_gate_outputs_to_numpy 7 | 8 | 9 | class ParrotLogger(SummaryWriter): 10 | def __init__(self, logdir, ali_path='ali'): 11 | super(ParrotLogger, 
self).__init__(logdir) 12 | ali_path = os.path.join(logdir, ali_path) 13 | if not os.path.exists(ali_path): 14 | os.makedirs(ali_path) 15 | self.ali_path = ali_path 16 | 17 | def log_training(self, reduced_loss, reduced_losses, reduced_acces, grad_norm, learning_rate, duration, 18 | iteration): 19 | 20 | self.add_scalar("training.loss", reduced_loss, iteration) 21 | self.add_scalar("training.loss.recon", reduced_losses[0], iteration) 22 | self.add_scalar("training.loss.recon_post", reduced_losses[1], iteration) 23 | self.add_scalar("training.loss.stop", reduced_losses[2], iteration) 24 | self.add_scalar("training.loss.contr", reduced_losses[3], iteration) 25 | self.add_scalar("training.loss.spenc", reduced_losses[4], iteration) 26 | self.add_scalar("training.loss.spcla", reduced_losses[5], iteration) 27 | self.add_scalar("training.loss.texcl", reduced_losses[6], iteration) 28 | self.add_scalar("training.loss.spadv", reduced_losses[7], iteration) 29 | 30 | self.add_scalar("grad.norm", grad_norm, iteration) 31 | self.add_scalar("learning.rate", learning_rate, iteration) 32 | self.add_scalar("duration", duration, iteration) 33 | 34 | 35 | self.add_scalar('training.acc.spenc', reduced_acces[0], iteration) 36 | self.add_scalar('training.acc.spcla', reduced_acces[1], iteration) 37 | self.add_scalar('training.acc.texcl', reduced_acces[2], iteration) 38 | 39 | def log_validation(self, reduced_loss, reduced_losses, reduced_acces, model, y, y_pred, iteration, task): 40 | 41 | self.add_scalar('validation.loss.%s'%task, reduced_loss, iteration) 42 | self.add_scalar("validation.loss.%s.recon"%task, reduced_losses[0], iteration) 43 | self.add_scalar("validation.loss.%s.recon_post"%task, reduced_losses[1], iteration) 44 | self.add_scalar("validation.loss.%s.stop"%task, reduced_losses[2], iteration) 45 | self.add_scalar("validation.loss.%s.contr"%task, reduced_losses[3], iteration) 46 | self.add_scalar("validation.loss.%s.spenc"%task, reduced_losses[4], iteration) 47 | self.add_scalar("validation.loss.%s.spcla"%task, reduced_losses[5], iteration) 48 | self.add_scalar("validation.loss.%s.texcl"%task, reduced_losses[6], iteration) 49 | self.add_scalar("validation.loss.%s.spadv"%task, reduced_losses[7], iteration) 50 | 51 | self.add_scalar('validation.acc.%s.spenc'%task, reduced_acces[0], iteration) 52 | self.add_scalar('validation.acc.%s.spcla'%task, reduced_acces[1], iteration) 53 | self.add_scalar('validatoin.acc.%s.texcl'%task, reduced_acces[2], iteration) 54 | 55 | predicted_mel, post_output, predicted_stop, alignments, \ 56 | text_hidden, mel_hidden, text_logit_from_mel_hidden, \ 57 | audio_seq2seq_alignments, \ 58 | speaker_logit_from_mel_hidden, \ 59 | text_lengths, mel_lengths = y_pred 60 | 61 | text_target, mel_target, spc_target, speaker_target, stop_target = y 62 | 63 | stop_target = stop_target.reshape(stop_target.size(0), -1, int(stop_target.size(1)/predicted_stop.size(1))) 64 | stop_target = stop_target[:,:,0] 65 | 66 | # plot distribution of parameters 67 | #for tag, value in model.named_parameters(): 68 | # tag = tag.replace('.', '/') 69 | # self.add_histogram(tag, value.data.cpu().numpy(), iteration) 70 | 71 | # plot alignment, mel target and predicted, stop target and predicted 72 | idx = random.randint(0, alignments.size(0) - 1) 73 | 74 | alignments = alignments.data.cpu().numpy() 75 | audio_seq2seq_alignments = audio_seq2seq_alignments.data.cpu().numpy() 76 | 77 | self.add_image( 78 | "%s.alignment"%task, 79 | plot_alignment_to_numpy(alignments[idx].T), 80 | iteration, dataformats='HWC') 
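# alignments[idx] is [decoder_steps, encoder_steps]; the transpose above puts
# encoder steps on the rows (y-axis) and decoder steps on the columns (x-axis)
# of the logged image, the usual Tacotron-style alignment layout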
81 | 82 | # plot more alignments 83 | plot_alignment(alignments[:4], self.ali_path+'/step-%d-%s.pdf'%(iteration, task)) 84 | 85 | self.add_image( 86 | "%s.audio_seq2seq_alignment"%task, 87 | plot_alignment_to_numpy(audio_seq2seq_alignments[idx].T), 88 | iteration, dataformats='HWC') 89 | 90 | self.add_image( 91 | "%s.mel_target"%task, 92 | plot_spectrogram_to_numpy(mel_target[idx].data.cpu().numpy()), 93 | iteration, dataformats='HWC') 94 | 95 | self.add_image( 96 | "%s.mel_predicted"%task, 97 | plot_spectrogram_to_numpy(predicted_mel[idx].data.cpu().numpy()), 98 | iteration, dataformats='HWC') 99 | 100 | self.add_image( 101 | "%s.spc_target"%task, 102 | plot_spectrogram_to_numpy(spc_target[idx].data.cpu().numpy()), 103 | iteration, dataformats='HWC') 104 | 105 | self.add_image( 106 | "%s.post_predicted"%task, 107 | plot_spectrogram_to_numpy(post_output[idx].data.cpu().numpy()), 108 | iteration, dataformats='HWC') 109 | 110 | self.add_image( 111 | "%s.stop"%task, 112 | plot_gate_outputs_to_numpy( 113 | stop_target[idx].data.cpu().numpy(), 114 | F.sigmoid(predicted_stop[idx]).data.cpu().numpy()), 115 | iteration, dataformats='HWC') 116 | -------------------------------------------------------------------------------- /fine-tune/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Parrot 2 | from .loss import ParrotLoss 3 | from .utils import lcm,gcd -------------------------------------------------------------------------------- /fine-tune/model/basic_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def tile(x, count, dim=0): 7 | """ 8 | Tiles x on dimension dim count times. 
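    For example, tile(x, 3, dim=0) on a (B, T) tensor returns a (3*B, T) tensor
    in which each row of x is repeated 3 times consecutively (the usual way of
    expanding encoder states for beam search).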
9 | """ 10 | perm = list(range(len(x.size()))) 11 | if dim != 0: 12 | perm[0], perm[dim] = perm[dim], perm[0] 13 | x = x.permute(perm).contiguous() 14 | out_size = list(x.size()) 15 | out_size[0] *= count 16 | batch = x.size(0) 17 | x = x.view(batch, -1) \ 18 | .transpose(0, 1) \ 19 | .repeat(count, 1) \ 20 | .transpose(0, 1) \ 21 | .contiguous() \ 22 | .view(*out_size) 23 | if dim != 0: 24 | x = x.permute(perm).contiguous() 25 | return x 26 | 27 | 28 | def sort_batch(data, lengths): 29 | ''' 30 | sort data by length 31 | sorted_data[initial_index] == data 32 | ''' 33 | sorted_lengths, sorted_index = lengths.sort(0, descending=True) 34 | sorted_data = data[sorted_index] 35 | _, initial_index = sorted_index.sort(0, descending=False) 36 | 37 | return sorted_data, sorted_lengths, initial_index 38 | 39 | 40 | class LinearNorm(torch.nn.Module): 41 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 42 | super(LinearNorm, self).__init__() 43 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 44 | 45 | torch.nn.init.xavier_uniform_( 46 | self.linear_layer.weight, 47 | gain=torch.nn.init.calculate_gain(w_init_gain)) 48 | 49 | def forward(self, x): 50 | return self.linear_layer(x) 51 | 52 | 53 | class ConvNorm(torch.nn.Module): 54 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 55 | padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): 56 | super(ConvNorm, self).__init__() 57 | if padding is None: 58 | assert(kernel_size % 2 == 1) 59 | padding = int(dilation * (kernel_size - 1) / 2) 60 | 61 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 62 | kernel_size=kernel_size, stride=stride, 63 | padding=padding, dilation=dilation, 64 | bias=bias) 65 | 66 | torch.nn.init.xavier_uniform_( 67 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) 68 | 69 | def forward(self, signal): 70 | conv_signal = self.conv(signal) 71 | return conv_signal 72 | 73 | 74 | class Prenet(nn.Module): 75 | def __init__(self, in_dim, sizes): 76 | super(Prenet, self).__init__() 77 | in_sizes = [in_dim] + sizes[:-1] 78 | self.layers = nn.ModuleList( 79 | [LinearNorm(in_size, out_size, bias=False) 80 | for (in_size, out_size) in zip(in_sizes, sizes)]) 81 | 82 | def forward(self, x): 83 | for linear in self.layers: 84 | x = F.dropout(F.relu(linear(x)), p=0.5, training=True) 85 | return x 86 | 87 | 88 | class LocationLayer(nn.Module): 89 | def __init__(self, attention_n_filters, attention_kernel_size, 90 | attention_dim): 91 | super(LocationLayer, self).__init__() 92 | padding = int((attention_kernel_size - 1) / 2) 93 | self.location_conv = ConvNorm(2, attention_n_filters, 94 | kernel_size=attention_kernel_size, 95 | padding=padding, bias=False, stride=1, 96 | dilation=1) 97 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 98 | bias=False, w_init_gain='tanh') 99 | 100 | def forward(self, attention_weights_cat): 101 | processed_attention = self.location_conv(attention_weights_cat) 102 | processed_attention = processed_attention.transpose(1, 2) 103 | processed_attention = self.location_dense(processed_attention) 104 | return processed_attention 105 | 106 | 107 | class Attention(nn.Module): 108 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 109 | attention_location_n_filters, attention_location_kernel_size): 110 | super(Attention, self).__init__() 111 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 112 | bias=False, w_init_gain='tanh') 113 | self.memory_layer = 
LinearNorm(embedding_dim, attention_dim, bias=False, 114 | w_init_gain='tanh') 115 | self.v = LinearNorm(attention_dim, 1, bias=False) 116 | self.location_layer = LocationLayer(attention_location_n_filters, 117 | attention_location_kernel_size, 118 | attention_dim) 119 | self.score_mask_value = -float("inf") 120 | 121 | def get_alignment_energies(self, query, processed_memory, 122 | attention_weights_cat): 123 | """ 124 | PARAMS 125 | ------ 126 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 127 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 128 | attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) 129 | RETURNS 130 | ------- 131 | alignment (batch, max_time) 132 | """ 133 | 134 | processed_query = self.query_layer(query.unsqueeze(1)) 135 | processed_attention_weights = self.location_layer(attention_weights_cat) 136 | energies = self.v(torch.tanh( 137 | processed_query + processed_attention_weights + processed_memory)) 138 | 139 | energies = energies.squeeze(-1) 140 | return energies 141 | 142 | def forward(self, attention_hidden_state, memory, processed_memory, 143 | attention_weights_cat, mask): 144 | """ 145 | PARAMS 146 | ------ 147 | attention_hidden_state: attention rnn last output 148 | memory: encoder outputs 149 | processed_memory: processed encoder outputs 150 | attention_weights_cat: previous and cummulative attention weights 151 | mask: binary mask for padded data 152 | """ 153 | alignment = self.get_alignment_energies( 154 | attention_hidden_state, processed_memory, attention_weights_cat) 155 | 156 | if mask is not None: 157 | alignment.data.masked_fill_(mask, self.score_mask_value) 158 | 159 | attention_weights = F.softmax(alignment, dim=1) 160 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 161 | attention_context = attention_context.squeeze(1) 162 | 163 | return attention_context, attention_weights 164 | 165 | 166 | class ForwardAttentionV2(nn.Module): 167 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 168 | attention_location_n_filters, attention_location_kernel_size): 169 | super(ForwardAttentionV2, self).__init__() 170 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 171 | bias=False, w_init_gain='tanh') 172 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 173 | w_init_gain='tanh') 174 | self.v = LinearNorm(attention_dim, 1, bias=False) 175 | self.location_layer = LocationLayer(attention_location_n_filters, 176 | attention_location_kernel_size, 177 | attention_dim) 178 | self.score_mask_value = -float(1e20) 179 | 180 | def get_alignment_energies(self, query, processed_memory, 181 | attention_weights_cat): 182 | """ 183 | PARAMS 184 | ------ 185 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 186 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 187 | attention_weights_cat: prev. 
and cumulative att weights (B, 2, max_time) 188 | RETURNS 189 | ------- 190 | alignment (batch, max_time) 191 | """ 192 | 193 | processed_query = self.query_layer(query.unsqueeze(1)) 194 | processed_attention_weights = self.location_layer(attention_weights_cat) 195 | energies = self.v(torch.tanh( 196 | processed_query + processed_attention_weights + processed_memory)) 197 | 198 | energies = energies.squeeze(-1) 199 | return energies 200 | 201 | def forward(self, attention_hidden_state, memory, processed_memory, 202 | attention_weights_cat, mask, log_alpha): 203 | """ 204 | PARAMS 205 | ------ 206 | attention_hidden_state: attention rnn last output 207 | memory: encoder outputs 208 | processed_memory: processed encoder outputs 209 | attention_weights_cat: previous and cummulative attention weights 210 | mask: binary mask for padded data 211 | """ 212 | log_energy = self.get_alignment_energies( 213 | attention_hidden_state, processed_memory, attention_weights_cat) 214 | 215 | #log_energy = 216 | 217 | if mask is not None: 218 | log_energy.data.masked_fill_(mask, self.score_mask_value) 219 | 220 | #attention_weights = F.softmax(alignment, dim=1) 221 | 222 | #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] 223 | #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] 224 | 225 | #log_total_score = log_alpha + content_score 226 | 227 | #previous_attention_weights = attention_weights_cat[:,0,:] 228 | 229 | log_alpha_shift_padded = [] 230 | max_time = log_energy.size(1) 231 | for sft in range(2): 232 | shifted = log_alpha[:,:max_time-sft] 233 | shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value) 234 | log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) 235 | 236 | biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2) 237 | 238 | log_alpha_new = biased + log_energy 239 | 240 | attention_weights = F.softmax(log_alpha_new, dim=1) 241 | 242 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 243 | attention_context = attention_context.squeeze(1) 244 | 245 | return attention_context, attention_weights, log_alpha_new -------------------------------------------------------------------------------- /fine-tune/model/beam.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from .penalties import PenaltyBuilder 5 | 6 | 7 | 8 | class Beam(object): 9 | """ 10 | ''' 11 | adapt from opennmt 12 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam.py 13 | ''' 14 | 15 | Class for managing the internals of the beam search process. 16 | Takes care of beams, back pointers, and scores. 17 | Args: 18 | size (int): beam size 19 | pad, bos, eos (int): indices of padding, beginning, and ending. 20 | n_best (int): nbest size to use 21 | cuda (bool): use gpu 22 | global_scorer (:obj:`GlobalScorer`) 23 | """ 24 | 25 | def __init__(self, size, pad, bos, eos, 26 | n_best=1, cuda=False, 27 | global_scorer=None, 28 | min_length=0, 29 | stepwise_penalty=False, 30 | block_ngram_repeat=0, 31 | exclusion_tokens=set()): 32 | 33 | self.size = size 34 | self.tt = torch.cuda if cuda else torch 35 | 36 | # The score for each translation on the beam. 37 | self.scores = self.tt.FloatTensor(size).zero_() 38 | self.all_scores = [] 39 | 40 | # The backpointers at each time-step. 41 | self.prev_ks = [] 42 | 43 | # The outputs at each time-step. 
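        # next_ys[t] holds the `size` token ids kept on the beam at step t; the
        # first entry is all padding except slot 0, which starts the search with BOS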
44 | self.next_ys = [self.tt.LongTensor(size) 45 | .fill_(pad)] 46 | self.next_ys[0][0] = bos 47 | 48 | # Has EOS topped the beam yet. 49 | self._eos = eos 50 | self.eos_top = False 51 | 52 | # The attentions (matrix) for each time. 53 | self.attn = [] 54 | self.hidden = [] 55 | 56 | # Time and k pair for finished. 57 | self.finished = [] 58 | self.n_best = n_best 59 | 60 | # Information for global scoring. 61 | self.global_scorer = global_scorer 62 | self.global_state = {} 63 | 64 | # Minimum prediction length 65 | self.min_length = min_length 66 | 67 | # Apply Penalty at every step 68 | self.stepwise_penalty = stepwise_penalty 69 | self.block_ngram_repeat = block_ngram_repeat 70 | self.exclusion_tokens = exclusion_tokens 71 | 72 | def get_current_state(self): 73 | "Get the outputs for the current timestep." 74 | return self.next_ys[-1] 75 | 76 | def get_current_origin(self): 77 | "Get the backpointers for the current timestep." 78 | return self.prev_ks[-1] 79 | 80 | def advance(self, word_probs, attn_out, hidden): 81 | """ 82 | Given prob over words for every last beam `wordLk` and attention 83 | `attn_out`: Compute and update the beam search. 84 | Parameters: 85 | * `word_probs`- probs of advancing from the last step (K x words) 86 | * `attn_out`- attention at the last step 87 | Returns: True if beam search is complete. 88 | """ 89 | num_words = word_probs.size(1) 90 | if self.stepwise_penalty: 91 | self.global_scorer.update_score(self, attn_out) 92 | # force the output to be longer than self.min_length 93 | cur_len = len(self.next_ys) 94 | if cur_len < self.min_length: 95 | for k in range(len(word_probs)): 96 | word_probs[k][self._eos] = -1e20 97 | # Sum the previous scores. 98 | if len(self.prev_ks) > 0: 99 | beam_scores = word_probs + self.scores.unsqueeze(1) 100 | # Don't let EOS have children. 
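            # a beam whose last token is already EOS is given score -1e20, so the
            # top-k selection below never extends it further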
101 | for i in range(self.next_ys[-1].size(0)): 102 | if self.next_ys[-1][i] == self._eos: 103 | beam_scores[i] = -1e20 104 | 105 | # Block ngram repeats 106 | if self.block_ngram_repeat > 0: 107 | ngrams = [] 108 | le = len(self.next_ys) 109 | for j in range(self.next_ys[-1].size(0)): 110 | hyp, _ = self.get_hyp(le - 1, j) 111 | ngrams = set() 112 | fail = False 113 | gram = [] 114 | for i in range(le - 1): 115 | # Last n tokens, n = block_ngram_repeat 116 | gram = (gram + 117 | [hyp[i].item()])[-self.block_ngram_repeat:] 118 | # Skip the blocking if it is in the exclusion list 119 | if set(gram) & self.exclusion_tokens: 120 | continue 121 | if tuple(gram) in ngrams: 122 | fail = True 123 | ngrams.add(tuple(gram)) 124 | if fail: 125 | beam_scores[j] = -10e20 126 | else: 127 | beam_scores = word_probs[0] 128 | flat_beam_scores = beam_scores.view(-1) 129 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, 130 | True, True) 131 | 132 | self.all_scores.append(self.scores) 133 | self.scores = best_scores 134 | 135 | # best_scores_id is flattened beam x word array, so calculate which 136 | # word and beam each score came from 137 | prev_k = best_scores_id / num_words 138 | self.prev_ks.append(prev_k) 139 | self.next_ys.append((best_scores_id - prev_k * num_words)) 140 | self.attn.append(attn_out.index_select(0, prev_k)) 141 | self.hidden.append(hidden.index_select(0, prev_k)) 142 | self.global_scorer.update_global_state(self) 143 | 144 | for i in range(self.next_ys[-1].size(0)): 145 | if self.next_ys[-1][i] == self._eos: 146 | global_scores = self.global_scorer.score(self, self.scores) 147 | s = global_scores[i] 148 | self.finished.append((s, len(self.next_ys) - 1, i)) 149 | 150 | # End condition is when top-of-beam is EOS and no global score. 151 | if self.next_ys[-1][0] == self._eos: 152 | self.all_scores.append(self.scores) 153 | self.eos_top = True 154 | 155 | def done(self): 156 | return self.eos_top and len(self.finished) >= self.n_best 157 | 158 | def sort_finished(self, minimum=None): 159 | if minimum is not None: 160 | i = 0 161 | # Add from beam until we have minimum outputs. 162 | while len(self.finished) < minimum: 163 | global_scores = self.global_scorer.score(self, self.scores) 164 | s = global_scores[i] 165 | self.finished.append((s, len(self.next_ys) - 1, i)) 166 | i += 1 167 | 168 | self.finished.sort(key=lambda a: -a[0]) 169 | scores = [sc for sc, _, _ in self.finished] 170 | ks = [(t, k) for _, t, k in self.finished] 171 | return scores, ks 172 | 173 | def get_hyp(self, timestep, k): 174 | """ 175 | Walk back to construct the full hypothesis. 176 | """ 177 | hyp, attn, hidden = [], [], [] 178 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 179 | hyp.append(self.next_ys[j + 1][k]) 180 | attn.append(self.attn[j][k]) 181 | hidden.append(self.hidden[j][k]) 182 | k = self.prev_ks[j][k] 183 | return torch.stack(hyp[::-1]), torch.stack(attn[::-1]), torch.stack(hidden[::-1]) 184 | 185 | 186 | class GNMTGlobalScorer(object): 187 | """ 188 | NMT re-ranking score from 189 | "Google's Neural Machine Translation System" :cite:`wu2016google` 190 | Args: 191 | alpha (float): length parameter 192 | beta (float): coverage parameter 193 | """ 194 | 195 | def __init__(self, opt=None): 196 | self.alpha = 0. 197 | self.beta = 0. 
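        # alpha (length) and beta (coverage) are fixed at 0 here; the PenaltyBuilder
        # call below presumably selects no coverage penalty and average-length
        # normalization (see penalties.py, adapted from OpenNMT, not shown here)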
198 | penalty_builder = PenaltyBuilder('none', 199 | 'avg') 200 | # Term will be subtracted from probability 201 | self.cov_penalty = penalty_builder.coverage_penalty() 202 | # Probability will be divided by this 203 | self.length_penalty = penalty_builder.length_penalty() 204 | 205 | def score(self, beam, logprobs): 206 | """ 207 | Rescores a prediction based on penalty functions 208 | """ 209 | normalized_probs = self.length_penalty(beam, 210 | logprobs, 211 | self.alpha) 212 | if not beam.stepwise_penalty: 213 | penalty = self.cov_penalty(beam, 214 | beam.global_state["coverage"], 215 | self.beta) 216 | normalized_probs -= penalty 217 | 218 | return normalized_probs 219 | 220 | def update_score(self, beam, attn): 221 | """ 222 | Function to update scores of a Beam that is not finished 223 | """ 224 | if "prev_penalty" in list(beam.global_state.keys()): 225 | beam.scores.add_(beam.global_state["prev_penalty"]) 226 | penalty = self.cov_penalty(beam, 227 | beam.global_state["coverage"] + attn, 228 | self.beta) 229 | beam.scores.sub_(penalty) 230 | 231 | def update_global_state(self, beam): 232 | "Keeps the coverage vector as sum of attentions" 233 | if len(beam.prev_ks) == 1: 234 | beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0) 235 | beam.global_state["coverage"] = beam.attn[-1] 236 | self.cov_total = beam.attn[-1].sum(1) 237 | else: 238 | self.cov_total += torch.min(beam.attn[-1], 239 | beam.global_state['coverage']).sum(1) 240 | beam.global_state["coverage"] = beam.global_state["coverage"] \ 241 | .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1]) 242 | 243 | prev_penalty = self.cov_penalty(beam, 244 | beam.global_state["coverage"], 245 | self.beta) 246 | beam.global_state["prev_penalty"] = prev_penalty -------------------------------------------------------------------------------- /fine-tune/model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .basic_layers import ConvNorm, LinearNorm, ForwardAttentionV2, Prenet 6 | from .utils import get_mask_from_lengths 7 | 8 | 9 | class Decoder(nn.Module): 10 | def __init__(self, hparams): 11 | super(Decoder, self).__init__() 12 | self.n_mel_channels = hparams.n_mel_channels 13 | self.n_frames_per_step = hparams.n_frames_per_step_decoder 14 | self.hidden_cat_dim = hparams.encoder_embedding_dim + hparams.speaker_embedding_dim 15 | self.attention_rnn_dim = hparams.attention_rnn_dim 16 | self.decoder_rnn_dim = hparams.decoder_rnn_dim 17 | self.prenet_dim = hparams.prenet_dim 18 | self.max_decoder_steps = hparams.max_decoder_steps 19 | self.stop_threshold = hparams.stop_threshold 20 | self.feed_back_last = hparams.feed_back_last 21 | 22 | if hparams.feed_back_last: 23 | prenet_input_dim = hparams.n_mel_channels 24 | else: 25 | prenet_input_dim = hparams.n_mel_channels * hparams.n_frames_per_step_decoder 26 | 27 | self.prenet = Prenet( 28 | prenet_input_dim , 29 | hparams.prenet_dim) 30 | 31 | self.attention_rnn = nn.LSTMCell( 32 | hparams.prenet_dim[-1] + self.hidden_cat_dim, 33 | hparams.attention_rnn_dim) 34 | 35 | self.attention_layer = ForwardAttentionV2( 36 | hparams.attention_rnn_dim, 37 | self.hidden_cat_dim, 38 | hparams.attention_dim, hparams.attention_location_n_filters, 39 | hparams.attention_location_kernel_size) 40 | 41 | self.decoder_rnn = nn.LSTMCell( 42 | self.hidden_cat_dim + hparams.attention_rnn_dim, 43 | hparams.decoder_rnn_dim) 44 | 
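        # the projection below maps [decoder LSTM output ; attention context] to
        # n_frames_per_step_decoder mel frames at once, i.e.
        # n_mel_channels * n_frames_per_step_decoder outputs per decoder step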
45 | self.linear_projection = LinearNorm( 46 | self.hidden_cat_dim + hparams.decoder_rnn_dim, 47 | hparams.n_mel_channels * hparams.n_frames_per_step_decoder) 48 | 49 | self.stop_layer = LinearNorm( 50 | self.hidden_cat_dim + hparams.decoder_rnn_dim, 1, 51 | bias=True, w_init_gain='sigmoid') 52 | 53 | def get_go_frame(self, memory): 54 | """ Gets all zeros frames to use as first decoder input 55 | PARAMS 56 | ------ 57 | memory: decoder outputs 58 | RETURNS 59 | ------- 60 | decoder_input: all zeros frames 61 | """ 62 | B = memory.size(0) 63 | if self.feed_back_last: 64 | input_dim = self.n_mel_channels 65 | else: 66 | input_dim = self.n_mel_channels * self.n_frames_per_step 67 | 68 | decoder_input = Variable(memory.data.new( 69 | B, input_dim).zero_()) 70 | return decoder_input 71 | 72 | def initialize_decoder_states(self, memory, mask): 73 | """ Initializes attention rnn states, decoder rnn states, attention 74 | weights, attention cumulative weights, attention context, stores memory 75 | and stores processed memory 76 | PARAMS 77 | ------ 78 | memory: Encoder outputs 79 | mask: Mask for padded data if training, expects None for inference 80 | """ 81 | B = memory.size(0) 82 | MAX_TIME = memory.size(1) 83 | 84 | self.attention_hidden = Variable(memory.data.new( 85 | B, self.attention_rnn_dim).zero_()) 86 | self.attention_cell = Variable(memory.data.new( 87 | B, self.attention_rnn_dim).zero_()) 88 | 89 | self.decoder_hidden = Variable(memory.data.new( 90 | B, self.decoder_rnn_dim).zero_()) 91 | self.decoder_cell = Variable(memory.data.new( 92 | B, self.decoder_rnn_dim).zero_()) 93 | 94 | self.attention_weights = Variable(memory.data.new( 95 | B, MAX_TIME).zero_()) 96 | self.attention_weights_cum = Variable(memory.data.new( 97 | B, MAX_TIME).zero_()) 98 | self.attention_context = Variable(memory.data.new( 99 | B, self.hidden_cat_dim).zero_()) 100 | 101 | self.log_alpha = Variable(memory.data.new(B, MAX_TIME).fill_(-float(1e20))) 102 | self.log_alpha[:, 0].fill_(0.) 103 | 104 | self.memory = memory 105 | self.processed_memory = self.attention_layer.memory_layer(memory) 106 | self.mask = mask 107 | 108 | def parse_decoder_inputs(self, decoder_inputs): 109 | """ Prepares decoder inputs, i.e. mel outputs 110 | PARAMS 111 | ------ 112 | decoder_inputs: inputs used for teacher-forced training, i.e. 
mel-specs 113 | RETURNS 114 | ------- 115 | inputs: processed decoder inputs 116 | """ 117 | # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) 118 | decoder_inputs = decoder_inputs.transpose(1, 2) 119 | decoder_inputs = decoder_inputs.reshape( 120 | decoder_inputs.size(0), 121 | int(decoder_inputs.size(1)/self.n_frames_per_step), -1) 122 | # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) 123 | decoder_inputs = decoder_inputs.transpose(0, 1) 124 | if self.feed_back_last: 125 | decoder_inputs = decoder_inputs[:,:,-self.n_mel_channels:] 126 | 127 | return decoder_inputs 128 | 129 | def parse_decoder_outputs(self, mel_outputs, stop_outputs, alignments): 130 | """ Prepares decoder outputs for output 131 | PARAMS 132 | ------ 133 | mel_outputs: 134 | stop_outputs: stop output energies 135 | alignments: 136 | RETURNS 137 | ------- 138 | mel_outputs: 139 | stop_outpust: stop output energies 140 | alignments: 141 | """ 142 | # (T_out, B, MAX_TIME) -> (B, T_out, MAX_TIME) 143 | alignments = torch.stack(alignments).transpose(0, 1) 144 | # (T_out, B) -> (B, T_out) 145 | if alignments.size(0) == 1: 146 | stop_outputs = torch.stack(stop_outputs).unsqueeze(0) 147 | else: 148 | stop_outputs = torch.stack(stop_outputs).transpose(0, 1) 149 | 150 | stop_outputs = stop_outputs.contiguous() 151 | # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) 152 | mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() 153 | # decouple frames per step 154 | mel_outputs = mel_outputs.view( 155 | mel_outputs.size(0), -1, self.n_mel_channels) 156 | # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) 157 | mel_outputs = mel_outputs.transpose(1, 2) 158 | 159 | return mel_outputs, stop_outputs, alignments 160 | 161 | def attend(self, decoder_input): 162 | cell_input = torch.cat((decoder_input, self.attention_context), -1) 163 | self.attention_hidden, self.attention_cell = self.attention_rnn( 164 | cell_input, (self.attention_hidden, self.attention_cell)) 165 | 166 | attention_weights_cat = torch.cat( 167 | (self.attention_weights.unsqueeze(1), 168 | self.attention_weights_cum.unsqueeze(1)), dim=1) 169 | 170 | self.attention_context, self.attention_weights, self.log_alpha = self.attention_layer( 171 | self.attention_hidden, self.memory, self.processed_memory, 172 | attention_weights_cat, self.mask, self.log_alpha) 173 | 174 | self.attention_weights_cum += self.attention_weights 175 | 176 | decoder_rnn_input = torch.cat( 177 | (self.attention_hidden, self.attention_context), -1) 178 | 179 | return decoder_rnn_input, self.attention_context, self.attention_weights 180 | 181 | def decode(self, decoder_input): 182 | 183 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 184 | decoder_input, (self.decoder_hidden, self.decoder_cell)) 185 | 186 | return self.decoder_hidden 187 | 188 | def forward(self, memory, decoder_inputs, memory_lengths): 189 | """ Decoder forward pass for training 190 | PARAMS 191 | ------ 192 | memory: Encoder outputs [B, encoder_max_time, hidden_dim] 193 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs [B, mel_bin, T] 194 | memory_lengths: Encoder output lengths for attention masking. 
[B] 195 | RETURNS 196 | ------- 197 | mel_outputs: mel outputs from the decoder [B, mel_bin, T] 198 | stop_outputs: stop outputs from the decoder [B, T/r] 199 | alignments: sequence of attention weights from the decoder [B, T/r, encoder_max_time] 200 | """ 201 | 202 | decoder_input = self.get_go_frame(memory).unsqueeze(0) 203 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs) 204 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) 205 | decoder_inputs = self.prenet(decoder_inputs) # [T/r + 1, B, prenet_dim ] 206 | 207 | self.initialize_decoder_states( 208 | memory, mask=~get_mask_from_lengths(memory_lengths)) 209 | 210 | mel_outputs, stop_outputs, alignments = [], [], [] 211 | while len(mel_outputs) < decoder_inputs.size(0) - 1: 212 | decoder_input = decoder_inputs[len(mel_outputs)] 213 | 214 | decoder_rnn_input, context, attention_weights = self.attend(decoder_input) 215 | 216 | decoder_rnn_output = self.decode(decoder_rnn_input) 217 | 218 | decoder_hidden_attention_context = torch.cat( 219 | (decoder_rnn_output, context), dim=1) 220 | 221 | mel_output = self.linear_projection(decoder_hidden_attention_context) 222 | stop_output = self.stop_layer(decoder_hidden_attention_context) 223 | 224 | mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze 225 | stop_outputs += [stop_output.squeeze()] 226 | alignments += [attention_weights] 227 | 228 | mel_outputs, stop_outputs, alignments = self.parse_decoder_outputs( 229 | mel_outputs, stop_outputs, alignments) 230 | 231 | return mel_outputs, stop_outputs, alignments 232 | 233 | def inference(self, memory): 234 | """ Decoder inference 235 | PARAMS 236 | ------ 237 | memory: Encoder outputs 238 | RETURNS 239 | ------- 240 | mel_outputs: mel outputs from the decoder 241 | stop_outputs: stop outputs from the decoder 242 | alignments: sequence of attention weights from the decoder 243 | """ 244 | decoder_input = self.get_go_frame(memory) 245 | 246 | self.initialize_decoder_states(memory, mask=None) 247 | 248 | mel_outputs, stop_outputs, alignments = [], [], [] 249 | while True: 250 | decoder_input = self.prenet(decoder_input) 251 | 252 | decoder_input_final, context, alignment = self.attend(decoder_input) 253 | 254 | #mel_output, stop_output, alignment = self.decode(decoder_input) 255 | decoder_rnn_output = self.decode(decoder_input_final) 256 | decoder_hidden_attention_context = torch.cat( 257 | (decoder_rnn_output, context), dim=1) 258 | 259 | mel_output = self.linear_projection(decoder_hidden_attention_context) 260 | stop_output = self.stop_layer(decoder_hidden_attention_context) 261 | 262 | mel_outputs += [mel_output.squeeze(1)] 263 | stop_outputs += [stop_output] 264 | alignments += [alignment] 265 | 266 | 267 | if torch.sigmoid(stop_output.data) > self.stop_threshold: 268 | break 269 | elif len(mel_outputs) == self.max_decoder_steps: 270 | print("Warning! 
Reached max decoder steps") 271 | break 272 | 273 | if self.feed_back_last: 274 | decoder_input = mel_output[:,-self.n_mel_channels:] 275 | else: 276 | decoder_input = mel_output 277 | 278 | mel_outputs, stop_outputs, alignments = self.parse_decoder_outputs( 279 | mel_outputs, stop_outputs, alignments) 280 | 281 | return mel_outputs, stop_outputs, alignments -------------------------------------------------------------------------------- /fine-tune/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .utils import get_mask_from_lengths 5 | 6 | class ParrotLoss(nn.Module): 7 | def __init__(self, hparams): 8 | super(ParrotLoss, self).__init__() 9 | 10 | self.L1Loss = nn.L1Loss(reduction='none') 11 | self.MSELoss = nn.MSELoss(reduction='none') 12 | self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss(reduction='none') 13 | self.CrossEntropyLoss = nn.CrossEntropyLoss(reduction='none') 14 | self.n_frames_per_step = hparams.n_frames_per_step_decoder 15 | self.eos = hparams.n_symbols 16 | self.predict_spectrogram = hparams.predict_spectrogram 17 | 18 | self.contr_w = hparams.contrastive_loss_w 19 | self.spenc_w = hparams.speaker_encoder_loss_w 20 | self.texcl_w = hparams.text_classifier_loss_w 21 | self.spadv_w = hparams.speaker_adversial_loss_w 22 | self.spcla_w = hparams.speaker_classifier_loss_w 23 | 24 | self.speaker_A = hparams.speaker_A 25 | self.speaker_B = hparams.speaker_B 26 | 27 | def parse_targets(self, targets, text_lengths): 28 | ''' 29 | text_target [batch_size, text_len] 30 | mel_target [batch_size, mel_bins, T] 31 | spc_target [batch_size, spc_bins, T] 32 | speaker_target [batch_size] 33 | stop_target [batch_size, T] 34 | ''' 35 | text_target, mel_target, spc_target, speaker_target, stop_target = targets 36 | 37 | B = stop_target.size(0) 38 | stop_target = stop_target.reshape(B, -1, self.n_frames_per_step) 39 | stop_target = stop_target[:, :, 0] 40 | 41 | padded = torch.tensor(text_target.data.new(B,1).zero_()) 42 | text_target = torch.cat((text_target, padded), dim=-1) 43 | 44 | # adding the ending token for target 45 | for bid in range(B): 46 | text_target[bid, text_lengths[bid].item()] = self.eos 47 | 48 | return text_target, mel_target, spc_target, speaker_target, stop_target 49 | 50 | def forward(self, model_outputs, targets, eps=1e-5): 51 | 52 | ''' 53 | predicted_mel [batch_size, mel_bins, T] 54 | predicted_stop [batch_size, T/r] 55 | alignment 56 | when input_text==True [batch_size, T/r, max_text_len] 57 | when input_text==False [batch_size, T/r, T/r] 58 | text_hidden [B, max_text_len, hidden_dim] 59 | mel_hidden [B, max_text_len, hidden_dim] 60 | text_logit_from_mel_hidden [B, max_text_len+1, n_symbols+1] 61 | speaker_logit_from_mel [B, n_speakers] 62 | speaker_logit_from_mel_hidden [B, max_text_len, n_speakers] 63 | text_lengths [B,] 64 | mel_lengths [B,] 65 | ''' 66 | predicted_mel, post_output, predicted_stop, alignments,\ 67 | text_hidden, mel_hidden, text_logit_from_mel_hidden, \ 68 | audio_seq2seq_alignments, \ 69 | speaker_logit_from_mel_hidden, \ 70 | text_lengths, mel_lengths = model_outputs 71 | 72 | text_target, mel_target, spc_target, speaker_target, stop_target = self.parse_targets(targets, text_lengths) 73 | 74 | 75 | ## get masks ## 76 | mel_mask = get_mask_from_lengths(mel_lengths, mel_target.size(2)).unsqueeze(1).expand(-1, mel_target.size(1), -1).float() 77 | spc_mask = get_mask_from_lengths(mel_lengths, 
mel_target.size(2)).unsqueeze(1).expand(-1, spc_target.size(1), -1).float() 78 | 79 | mel_step_lengths = torch.ceil(mel_lengths.float() / self.n_frames_per_step).long() 80 | stop_mask = get_mask_from_lengths(mel_step_lengths, 81 | int(mel_target.size(2)/self.n_frames_per_step)).float() # [B, T/r] 82 | text_mask = get_mask_from_lengths(text_lengths).float() 83 | text_mask_plus_one = get_mask_from_lengths(text_lengths + 1).float() 84 | 85 | # reconstruction loss # 86 | recon_loss = torch.sum(self.L1Loss(predicted_mel, mel_target) * mel_mask) / torch.sum(mel_mask) 87 | 88 | if self.predict_spectrogram: 89 | recon_loss_post = (self.L1Loss(post_output, spc_target) * spc_mask).sum() / spc_mask.sum() 90 | else: 91 | recon_loss_post = (self.L1Loss(post_output, mel_target) * mel_mask).sum() / torch.sum(mel_mask) 92 | 93 | stop_loss = torch.sum(self.BCEWithLogitsLoss(predicted_stop, stop_target) * stop_mask) / torch.sum(stop_mask) 94 | 95 | 96 | if self.contr_w == 0.: 97 | contrast_loss = torch.tensor(0.).cuda() 98 | else: 99 | # contrastive mask # 100 | contrast_mask1 = get_mask_from_lengths(text_lengths).unsqueeze(2).expand(-1, -1, mel_hidden.size(1)) # [B, text_len] -> [B, text_len, T/r] 101 | contrast_mask2 = get_mask_from_lengths(text_lengths).unsqueeze(1).expand(-1, text_hidden.size(1), -1) # [B, T/r] -> [B, text_len, T/r] 102 | contrast_mask = (contrast_mask1 & contrast_mask2).float() 103 | 104 | text_hidden_normed = text_hidden / (torch.norm(text_hidden, dim=2, keepdim=True) + eps) 105 | mel_hidden_normed = mel_hidden / (torch.norm(mel_hidden, dim=2, keepdim=True) + eps) 106 | 107 | # (x - y) ** 2 = x ** 2 + y ** 2 - 2xy 108 | distance_matrix_xx = torch.sum(text_hidden_normed ** 2, dim=2, keepdim=True) #[batch_size, text_len, 1] 109 | distance_matrix_yy = torch.sum(mel_hidden_normed ** 2, dim=2) 110 | distance_matrix_yy = distance_matrix_yy.unsqueeze(1) #[batch_size, 1, text_len] 111 | 112 | #[batch_size, text_len, text_len] 113 | distance_matrix_xy = torch.bmm(text_hidden_normed, torch.transpose(mel_hidden_normed, 1, 2)) 114 | distance_matrix = distance_matrix_xx + distance_matrix_yy - 2 * distance_matrix_xy 115 | 116 | TTEXT = distance_matrix.size(1) 117 | hard_alignments = torch.eye(TTEXT).cuda() 118 | contrast_loss = hard_alignments * distance_matrix + \ 119 | (1. - hard_alignments) * torch.max(1. 
- distance_matrix, torch.zeros_like(distance_matrix)) 120 | 121 | contrast_loss = torch.sum(contrast_loss * contrast_mask) / torch.sum(contrast_mask) 122 | 123 | n_speakers = speaker_logit_from_mel_hidden.size(2) 124 | TTEXT = speaker_logit_from_mel_hidden.size(1) 125 | n_symbols_plus_one = text_logit_from_mel_hidden.size(2) 126 | 127 | speaker_encoder_loss = torch.tensor(0.).cuda() 128 | speaker_encoder_acc = torch.tensor(0.).cuda() 129 | 130 | 131 | speaker_logit_flatten = speaker_logit_from_mel_hidden.reshape(-1) # -> [B* TTEXT] 132 | predicted_speaker = (F.sigmoid(speaker_logit_flatten) > 0.5).long() 133 | speaker_target_flatten = speaker_target.unsqueeze(1).expand(-1, TTEXT).reshape(-1) 134 | 135 | speaker_classification_acc = ((predicted_speaker == speaker_target_flatten).float() * text_mask.reshape(-1)).sum() / text_mask.sum() 136 | loss = self.BCEWithLogitsLoss(speaker_logit_flatten, speaker_target_flatten.float()) 137 | 138 | 139 | speaker_classification_loss = torch.sum(loss * text_mask.reshape(-1)) / torch.sum(text_mask) 140 | 141 | # text classification loss # 142 | text_logit_flatten = text_logit_from_mel_hidden.reshape(-1, n_symbols_plus_one) 143 | text_target_flatten = text_target.reshape(-1) 144 | _, predicted_text = torch.max(text_logit_flatten, dim=1) 145 | text_classification_acc = ((predicted_text == text_target_flatten).float()*text_mask_plus_one.reshape(-1)).sum()/text_mask_plus_one.sum() 146 | loss = self.CrossEntropyLoss(text_logit_flatten, text_target_flatten) 147 | text_classification_loss = torch.sum(loss * text_mask_plus_one.reshape(-1)) / torch.sum(text_mask_plus_one) 148 | 149 | # speaker adversival loss # 150 | flatten_target = 0.5 * torch.ones_like(speaker_logit_flatten) 151 | loss = self.MSELoss(F.sigmoid(speaker_logit_flatten), flatten_target) 152 | mask = text_mask.reshape(-1) 153 | speaker_adversial_loss = torch.sum(loss * mask) / torch.sum(mask) 154 | 155 | loss_list = [recon_loss, recon_loss_post, stop_loss, 156 | contrast_loss, speaker_encoder_loss, speaker_classification_loss, 157 | text_classification_loss, speaker_adversial_loss] 158 | 159 | acc_list = [speaker_encoder_acc, speaker_classification_acc, text_classification_acc] 160 | 161 | combined_loss1 = recon_loss + recon_loss_post + stop_loss + self.contr_w * contrast_loss + \ 162 | self.texcl_w * text_classification_loss + \ 163 | self.spadv_w * speaker_adversial_loss 164 | 165 | combined_loss2 = self.spcla_w * speaker_classification_loss 166 | 167 | 168 | 169 | return loss_list, acc_list, combined_loss1, combined_loss2 170 | 171 | 172 | def torch_test_grad(): 173 | 174 | x = torch.ones((1,1)) 175 | 176 | net1 = nn.Linear(1, 1, bias=False) 177 | 178 | 179 | net1.weight.data.fill_(2.) 180 | net2 = nn.Linear(1, 1, bias=False) 181 | net2.weight.data.fill_(3.) 182 | 183 | all_params = [] 184 | 185 | all_params.extend([p for p in net1.parameters()]) 186 | all_params.extend([p for p in net2.parameters()]) 187 | #print all_params 188 | 189 | y = net1(x) ** 2 190 | 191 | z = net2(y) ** 2 192 | 193 | loss1 = (z - 0.) 194 | loss2 = -5. * (z - 0.) 
** 2 195 | 196 | 197 | for p in net2.parameters(): 198 | p.requires_grad = False 199 | 200 | loss1.backward(retain_graph=True) 201 | 202 | 203 | print((net1.weight.grad)) 204 | print((net2.weight.grad)) 205 | 206 | opt = torch.optim.SGD(all_params, lr=0.1) 207 | opt.step() 208 | 209 | print((net1.weight)) 210 | print((net2.weight)) 211 | #net1.weight.data = net1.weight.data - 0.1 * net1.weight.grad.data 212 | 213 | for p in net2.parameters(): 214 | p.requires_grad=True 215 | 216 | for p in net1.parameters(): 217 | p.requires_grad=False 218 | 219 | loss2.backward() 220 | print((net1.weight)) 221 | print((net2.weight.grad)) 222 | print((net1.weight.grad)) 223 | 224 | net1.zero_grad() 225 | print((net1.weight.grad)) 226 | 227 | def test_logic(): 228 | a = torch.ByteTensor([1,0,0,0,0]) 229 | b = torch.ByteTensor([1,1,1,0,0]) 230 | 231 | print(~a) 232 | print(a & b) 233 | print(a | b) 234 | 235 | text_lengths = torch.IntTensor([2,4,3]).cuda() 236 | mel_hidden_lengths =torch.IntTensor([5,6,5]).cuda() 237 | contrast_mask1 = get_mask_from_lengths(text_lengths).unsqueeze(2).expand(-1, -1, 6) # [B, text_len] -> [B, text_len, T/r] 238 | contrast_mask2 = get_mask_from_lengths(mel_hidden_lengths).unsqueeze(1).expand(-1, 4, -1) # [B, T/r] -> [B, text_len, T/r] 239 | contrast_mask = contrast_mask1 & contrast_mask2 240 | print(contrast_mask) 241 | 242 | if __name__ == '__main__': 243 | 244 | torch_test_grad() 245 | 246 | -------------------------------------------------------------------------------- /fine-tune/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from math import sqrt 6 | from .utils import to_gpu 7 | from .decoder import Decoder 8 | from .layers import SpeakerClassifier, SpeakerEncoder, AudioSeq2seq, TextEncoder, PostNet, MergeNet 9 | 10 | 11 | class Parrot(nn.Module): 12 | def __init__(self, hparams): 13 | super(Parrot, self).__init__() 14 | 15 | #print hparams 16 | # plus 17 | self.embedding = nn.Embedding( 18 | hparams.n_symbols + 1, hparams.symbols_embedding_dim) 19 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) 20 | val = sqrt(3.0) * std 21 | 22 | self.sos = hparams.n_symbols 23 | 24 | self.embedding.weight.data.uniform_(-val, val) 25 | 26 | self.text_encoder = TextEncoder(hparams) 27 | 28 | self.audio_seq2seq = AudioSeq2seq(hparams) 29 | 30 | self.merge_net = MergeNet(hparams) 31 | 32 | self.speaker_encoder = SpeakerEncoder(hparams) 33 | 34 | self.speaker_classifier = SpeakerClassifier(hparams) 35 | 36 | self.decoder = Decoder(hparams) 37 | 38 | self.postnet = PostNet(hparams) 39 | 40 | self._initilize_emb(hparams) 41 | 42 | self.spemb_input = hparams.spemb_input 43 | 44 | def _initilize_emb(self, hparams): 45 | 46 | a_embedding = np.load(hparams.a_embedding_path) 47 | a_embedding = np.mean(a_embedding, axis=0) 48 | 49 | b_embedding = np.load(hparams.b_embedding_path) 50 | b_embedding = np.mean(b_embedding, axis=0) 51 | 52 | self.sp_embedding = nn.Embedding( 53 | hparams.n_speakers, hparams.speaker_embedding_dim) 54 | 55 | self.sp_embedding.weight.data[0] = torch.FloatTensor(a_embedding) 56 | self.sp_embedding.weight.data[1] = torch.FloatTensor(b_embedding) 57 | 58 | def grouped_parameters(self,): 59 | 60 | params_group1 = [p for p in self.embedding.parameters()] 61 | params_group1.extend([p for p in self.text_encoder.parameters()]) 62 | params_group1.extend([p for p in self.audio_seq2seq.parameters()]) 63 
| 64 | params_group1.extend([p for p in self.sp_embedding.parameters()]) 65 | params_group1.extend([p for p in self.merge_net.parameters()]) 66 | params_group1.extend([p for p in self.decoder.parameters()]) 67 | params_group1.extend([p for p in self.postnet.parameters()]) 68 | 69 | return params_group1, [p for p in self.speaker_classifier.parameters()] 70 | 71 | def parse_batch(self, batch): 72 | text_input_padded, mel_padded, spc_padded, speaker_id, \ 73 | text_lengths, mel_lengths, stop_token_padded = batch 74 | 75 | text_input_padded = to_gpu(text_input_padded).long() 76 | mel_padded = to_gpu(mel_padded).float() 77 | spc_padded = to_gpu(spc_padded).float() 78 | speaker_id = to_gpu(speaker_id).long() 79 | text_lengths = to_gpu(text_lengths).long() 80 | mel_lengths = to_gpu(mel_lengths).long() 81 | stop_token_padded = to_gpu(stop_token_padded).float() 82 | 83 | return ((text_input_padded, mel_padded, text_lengths, mel_lengths, speaker_id), 84 | (text_input_padded, mel_padded, spc_padded, speaker_id, stop_token_padded)) 85 | 86 | 87 | def forward(self, inputs, input_text): 88 | 89 | text_input_padded, mel_padded, text_lengths, mel_lengths, speaker_id = inputs 90 | 91 | text_input_embedded = self.embedding(text_input_padded.long()).transpose(1, 2) # -> [B, text_embedding_dim, max_text_len] 92 | text_hidden = self.text_encoder(text_input_embedded, text_lengths) # -> [B, max_text_len, hidden_dim] 93 | 94 | B = text_input_padded.size(0) 95 | start_embedding = Variable(text_input_padded.data.new(B,).fill_(self.sos)) 96 | start_embedding = self.embedding(start_embedding) 97 | 98 | speaker_embedding = self.sp_embedding(speaker_id) 99 | 100 | if self.spemb_input: 101 | T = mel_padded.size(2) 102 | audio_input = torch.cat((mel_padded, 103 | speaker_embedding.detach().unsqueeze(2).expand(-1, -1, T)), dim=1) 104 | else: 105 | audio_input = mel_padded 106 | 107 | #-> [B, text_len+1, hidden_dim] [B, text_len+1, n_symbols] [B, text_len+1, T/r] 108 | audio_seq2seq_hidden, audio_seq2seq_logit, audio_seq2seq_alignments = self.audio_seq2seq( 109 | audio_input, mel_lengths, text_input_embedded, start_embedding) 110 | audio_seq2seq_hidden= audio_seq2seq_hidden[:,:-1, :] # -> [B, text_len, hidden_dim] 111 | 112 | speaker_logit_from_mel_hidden = self.speaker_classifier(audio_seq2seq_hidden) # -> [B, text_len, n_speakers] 113 | 114 | if input_text: 115 | hidden = self.merge_net(text_hidden, text_lengths) 116 | else: 117 | hidden = self.merge_net(audio_seq2seq_hidden, text_lengths) 118 | 119 | L = hidden.size(1) 120 | hidden = torch.cat([hidden, speaker_embedding.unsqueeze(1).expand(-1, L, -1)], -1) 121 | 122 | predicted_mel, predicted_stop, alignments = self.decoder(hidden, mel_padded, text_lengths) 123 | 124 | post_output = self.postnet(predicted_mel) 125 | 126 | outputs = [predicted_mel, post_output, predicted_stop, alignments, 127 | text_hidden, audio_seq2seq_hidden, audio_seq2seq_logit, audio_seq2seq_alignments, 128 | speaker_logit_from_mel_hidden, 129 | text_lengths, mel_lengths] 130 | 131 | return outputs 132 | 133 | 134 | def inference(self, inputs, input_text, id_reference, beam_width): 135 | 136 | text_input_padded, mel_padded, text_lengths, mel_lengths, speaker_id = inputs 137 | text_input_embedded = self.embedding(text_input_padded.long()).transpose(1, 2) 138 | text_hidden = self.text_encoder.inference(text_input_embedded) 139 | 140 | B = text_input_padded.size(0) # B should be 1 141 | start_embedding = Variable(text_input_padded.data.new(B,).fill_(self.sos)) 142 | start_embedding = 
self.embedding(start_embedding) # [1, embedding_dim] 143 | 144 | #-> [B, text_len+1, hidden_dim] [B, text_len+1, n_symbols] [B, text_len+1, T/r] 145 | 146 | speaker_embedding = self.sp_embedding(speaker_id) 147 | 148 | if self.spemb_input: 149 | T = mel_padded.size(2) 150 | audio_input = torch.cat((mel_padded, 151 | speaker_embedding.unsqueeze(2).expand(-1,-1,T)), dim=1) 152 | else: 153 | audio_input = mel_padded 154 | 155 | audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments = self.audio_seq2seq.inference_beam( 156 | audio_input, start_embedding, self.embedding, beam_width=beam_width) 157 | audio_seq2seq_hidden= audio_seq2seq_hidden[:,:-1, :] # -> [B, text_len, hidden_dim] 158 | 159 | speaker_embedding = self.sp_embedding(id_reference) 160 | 161 | if input_text: 162 | hidden = self.merge_net.inference(text_hidden) 163 | else: 164 | hidden = self.merge_net.inference(audio_seq2seq_hidden) 165 | 166 | L = hidden.size(1) 167 | hidden = torch.cat([hidden, speaker_embedding.unsqueeze(1).expand(-1, L, -1)], -1) 168 | 169 | predicted_mel, predicted_stop, alignments = self.decoder.inference(hidden) 170 | 171 | post_output = self.postnet(predicted_mel) 172 | 173 | return (predicted_mel, post_output, predicted_stop, alignments, 174 | text_hidden, audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments, 175 | speaker_id) 176 | 177 | 178 | -------------------------------------------------------------------------------- /fine-tune/model/penalties.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """ 7 | Returns the Length and Coverage Penalty function for Beam Search. 8 | Args: 9 | length_pen (str): option name of length pen 10 | cov_pen (str): option name of cov pen 11 | """ 12 | 13 | def __init__(self, cov_pen, length_pen): 14 | self.length_pen = length_pen 15 | self.cov_pen = cov_pen 16 | 17 | def coverage_penalty(self): 18 | if self.cov_pen == "wu": 19 | return self.coverage_wu 20 | elif self.cov_pen == "summary": 21 | return self.coverage_summary 22 | else: 23 | return self.coverage_none 24 | 25 | def length_penalty(self): 26 | if self.length_pen == "wu": 27 | return self.length_wu 28 | elif self.length_pen == "avg": 29 | return self.length_average 30 | else: 31 | return self.length_none 32 | 33 | """ 34 | Below are all the different penalty terms implemented so far 35 | """ 36 | 37 | def coverage_wu(self, beam, cov, beta=0.): 38 | """ 39 | NMT coverage re-ranking score from 40 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 41 | """ 42 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(1) 43 | return beta * penalty 44 | 45 | def coverage_summary(self, beam, cov, beta=0.): 46 | """ 47 | Our summary penalty. 48 | """ 49 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(1) 50 | penalty -= cov.size(1) 51 | return beta * penalty 52 | 53 | def coverage_none(self, beam, cov, beta=0.): 54 | """ 55 | returns zero as penalty 56 | """ 57 | return beam.scores.clone().fill_(0.0) 58 | 59 | def length_wu(self, beam, logprobs, alpha=0.): 60 | """ 61 | NMT length re-ranking score from 62 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 63 | """ 64 | 65 | modifier = (((5 + len(beam.next_ys)) ** alpha) / 66 | ((5 + 1) ** alpha)) 67 | return (logprobs / modifier) 68 | 69 | def length_average(self, beam, logprobs, alpha=0.): 70 | """ 71 | Returns the average probability of tokens in a sequence. 
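In other words, the cumulative log-probability is divided by the number of tokens decoded so far, len(beam.next_ys). For example, with alpha=0.6 and 10 decoded tokens, length_wu above divides the score by ((5 + 10) ** 0.6) / ((5 + 1) ** 0.6), roughly 1.73, while this variant divides it by the raw length 10.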
72 | """ 73 | return logprobs / len(beam.next_ys) 74 | 75 | def length_none(self, beam, logprobs, alpha=0., beta=0.): 76 | """ 77 | Returns unmodified scores. 78 | """ 79 | return logprobs -------------------------------------------------------------------------------- /fine-tune/model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def gcd(a,b): 6 | a, b = (a, b) if a >=b else (b, a) 7 | if a%b == 0: 8 | return b 9 | else : 10 | return gcd(b,a%b) 11 | 12 | def lcm(a,b): 13 | return a*b//gcd(a,b) 14 | 15 | 16 | if __name__ == "__main__": 17 | print(lcm(3,2)) 18 | 19 | def get_mask_from_lengths(lengths, max_len=None): 20 | if max_len is None: 21 | max_len = torch.max(lengths).item() 22 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 23 | #print ids 24 | mask = (ids < lengths.unsqueeze(1)).byte() 25 | return mask 26 | 27 | def to_gpu(x): 28 | x = x.contiguous() 29 | 30 | if torch.cuda.is_available(): 31 | x = x.cuda(non_blocking=True) 32 | return torch.autograd.Variable(x) 33 | 34 | def test_mask(): 35 | lengths = torch.IntTensor([3,5,4]) 36 | print(torch.ceil(lengths.float() / 2)) 37 | 38 | data = torch.FloatTensor(3, 5, 2) # [B, T, D] 39 | data.fill_(1.) 40 | m = get_mask_from_lengths(lengths.cuda(), data.size(1)) 41 | print(m) 42 | m = m.unsqueeze(2).expand(-1,-1,data.size(2)).float() 43 | print(m) 44 | 45 | print(torch.sum(data.cuda() * m) / torch.sum(m)) 46 | 47 | 48 | def test_loss(): 49 | data1 = torch.FloatTensor(3, 5, 2) 50 | data1.fill_(1.) 51 | data2 = torch.FloatTensor(3, 5, 2) 52 | data2.fill_(2.) 53 | data2[0,0,0] = 1000 54 | 55 | l = torch.nn.L1Loss(reduction='none')(data1,data2) 56 | print(l) 57 | 58 | 59 | #if __name__ == '__main__': 60 | # test_mask() -------------------------------------------------------------------------------- /fine-tune/multiproc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sys 4 | import subprocess 5 | 6 | argslist = list(sys.argv)[1:] 7 | num_gpus = torch.cuda.device_count() 8 | argslist.append('--n_gpus={}'.format(num_gpus)) 9 | workers = [] 10 | job_id = time.strftime("%Y_%m_%d-%H%M%S") 11 | argslist.append("--group_name=group_{}".format(job_id)) 12 | 13 | for i in range(num_gpus): 14 | argslist.append('--rank={}'.format(i)) 15 | stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), 16 | "w") 17 | print(argslist) 18 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 19 | workers.append(p) 20 | argslist = argslist[:-1] 21 | 22 | for p in workers: 23 | p.wait() 24 | -------------------------------------------------------------------------------- /fine-tune/plotting_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 
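# fig.canvas.tostring_rgb() returns the raw RGB bytes of the rendered Agg canvas;
# the reshape below turns them into a (height, width, 3) uint8 image.
# Note: np.fromstring is deprecated in recent NumPy releases; the equivalent call
# is np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).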
9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data 12 | 13 | def plot_alignment(alignment, fn): 14 | # [4, encoder_step, decoder_step] 15 | fig, axes = plt.subplots(1, 2) 16 | 17 | for j in range(2): 18 | g = axes[j].imshow(alignment[j,:,:].T, 19 | aspect='auto', origin='lower', 20 | interpolation='none') 21 | plt.colorbar(g, ax=axes[j]) 22 | 23 | plt.savefig(fn) 24 | plt.close() 25 | return fn 26 | 27 | 28 | def plot_alignment_to_numpy(alignment, info=None): 29 | fig, ax = plt.subplots(figsize=(6, 4)) 30 | im = ax.imshow(alignment, aspect='auto', origin='lower', 31 | interpolation='none') 32 | fig.colorbar(im, ax=ax) 33 | xlabel = 'Decoder timestep' 34 | if info is not None: 35 | xlabel += '\n\n' + info 36 | plt.xlabel(xlabel) 37 | plt.ylabel('Encoder timestep') 38 | plt.tight_layout() 39 | 40 | fig.canvas.draw() 41 | data = save_figure_to_numpy(fig) 42 | plt.close() 43 | return data 44 | 45 | 46 | def plot_spectrogram_to_numpy(spectrogram): 47 | fig, ax = plt.subplots(figsize=(12, 3)) 48 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 49 | interpolation='none') 50 | plt.colorbar(im, ax=ax) 51 | plt.xlabel("Frames") 52 | plt.ylabel("Channels") 53 | plt.tight_layout() 54 | 55 | fig.canvas.draw() 56 | data = save_figure_to_numpy(fig) 57 | plt.close() 58 | return data 59 | 60 | 61 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 62 | fig, ax = plt.subplots(figsize=(12, 3)) 63 | ax.scatter(list(range(len(gate_targets))), gate_targets, alpha=0.5, 64 | color='green', marker='+', s=1, label='target') 65 | ax.scatter(list(range(len(gate_outputs))), gate_outputs, alpha=0.5, 66 | color='red', marker='.', s=1, label='predicted') 67 | 68 | plt.xlabel("Frames (Green target, Red predicted)") 69 | plt.ylabel("Gate State") 70 | plt.tight_layout() 71 | 72 | fig.canvas.draw() 73 | data = save_figure_to_numpy(fig) 74 | plt.close() 75 | return data 76 | -------------------------------------------------------------------------------- /fine-tune/reader/__init__.py: -------------------------------------------------------------------------------- 1 | from .reader import TextMelIDLoader, TextMelIDCollate 2 | from .symbols import sp2id, id2sp, id2ph -------------------------------------------------------------------------------- /fine-tune/reader/reader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import random 4 | import numpy as np 5 | from .symbols import ph2id 6 | from torch.utils.data import DataLoader 7 | 8 | def read_text(fn): 9 | text = [] 10 | with open(fn) as f: 11 | lines = f.readlines() 12 | for line in lines: 13 | start, end, phone = line.strip().split() 14 | text.append([int(start), int(end), phone]) 15 | return text 16 | 17 | class TextMelIDLoader(torch.utils.data.Dataset): 18 | 19 | def __init__(self, list_file, mean_std_file, speaker_A, speaker_B, shuffle=True, pids=None): 20 | 21 | file_path_list = [] 22 | 23 | with open(list_file) as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | path, n_frame, n_text = line.strip().split() 27 | speaker_id = path.split('/')[-3].split('_')[2] 28 | 29 | if not pids is None: 30 | if not speaker_id in pids: 31 | 32 | continue 33 | 34 | if int(n_frame) >= 1000: 35 | continue 36 | 37 | file_path_list.append(path) 38 | 39 | 40 | random.seed(1234) 41 | if shuffle: 42 | random.shuffle(file_path_list) 43 | 44 | self.file_path_list = 
file_path_list 45 | 46 | self.mel_mean_std = np.float32(np.load(mean_std_file)) 47 | self.spc_mean_std = np.float32(np.load(mean_std_file.replace('mel', 'spec'))) 48 | self.sp2id = {speaker_A:0,speaker_B:1} 49 | 50 | 51 | def get_path_id(self, path): 52 | # Customize this function to obtain paths and speaker id 53 | # Deduce filenames 54 | text_path = path.replace('spec', 'text').replace('npy', 'txt').replace('log-', '') 55 | mel_path = path.replace('spec', 'mel') 56 | speaker_id = path.split('/')[-3].split('_')[2] 57 | # use non-trimmed version # 58 | spec_path = path.replace('spec_trim', 'spec') 59 | text_path = text_path.replace('text_trim', 'text') 60 | mel_path = mel_path.replace('mel_trim', 'mel') 61 | speaker_id = path.split('/')[-3].split('_')[2] 62 | 63 | return mel_path, spec_path, text_path, speaker_id 64 | 65 | def get_text_mel_id_pair(self, path): 66 | ''' 67 | text_input [len_text] 68 | mel [mel_bin, len_mel] 69 | spc [spc_bin, len_mel] 70 | speaker_id [1] 71 | ''' 72 | 73 | mel_path, spec_path, text_path, speaker_id = self.get_path_id(path) 74 | # Load data from disk 75 | text_input = self.get_text(text_path) 76 | mel = np.load(mel_path) 77 | spc = np.load(spec_path) 78 | speaker_id = [self.sp2id[speaker_id]] 79 | # Normalize audio 80 | mel = (mel - self.mel_mean_std[0])/ self.mel_mean_std[1] 81 | spc = (spc - self.spc_mean_std[0]) / self.spc_mean_std[1] 82 | # Format for pytorch 83 | text_input = torch.LongTensor(text_input) 84 | mel = torch.from_numpy(np.transpose(mel)) 85 | spc = torch.from_numpy(np.transpose(spc)) 86 | speaker_id = torch.LongTensor(speaker_id) 87 | 88 | return (text_input, mel, spc, speaker_id) 89 | 90 | def get_text(self,text_path): 91 | 92 | text = read_text(text_path) 93 | text_input = [] 94 | 95 | for start, end, ph in text: 96 | dur = int((end - start) / 125000. 
+ 0.6) 97 | text_input.append(ph2id[ph]) 98 | 99 | return text_input 100 | 101 | def __getitem__(self, index): 102 | return self.get_text_mel_id_pair(self.file_path_list[index]) 103 | 104 | def __len__(self): 105 | return len(self.file_path_list) 106 | 107 | 108 | class TextMelIDCollate(): 109 | 110 | def __init__(self, n_frames_per_step=2): 111 | self.n_frames_per_step = n_frames_per_step 112 | 113 | def __call__(self, batch): 114 | ''' 115 | batch is list of (text_input, mel, spc, speaker_id) 116 | ''' 117 | 118 | text_lengths = torch.IntTensor([len(x[0]) for x in batch]) 119 | mel_lengths = torch.IntTensor([x[1].size(1) for x in batch]) 120 | mel_bin = batch[0][1].size(0) 121 | spc_bin = batch[0][2].size(0) 122 | 123 | max_text_len = torch.max(text_lengths).item() 124 | max_mel_len = torch.max(mel_lengths).item() 125 | if max_mel_len % self.n_frames_per_step != 0: 126 | max_mel_len += self.n_frames_per_step - max_mel_len % self.n_frames_per_step 127 | assert max_mel_len % self.n_frames_per_step == 0 128 | 129 | text_input_padded = torch.LongTensor(len(batch), max_text_len) 130 | mel_padded = torch.FloatTensor(len(batch), mel_bin, max_mel_len) 131 | spc_padded = torch.FloatTensor(len(batch), spc_bin, max_mel_len) 132 | 133 | speaker_id = torch.LongTensor(len(batch)) 134 | stop_token_padded = torch.FloatTensor(len(batch), max_mel_len) 135 | 136 | text_input_padded.zero_() 137 | mel_padded.zero_() 138 | spc_padded.zero_() 139 | speaker_id.zero_() 140 | stop_token_padded.zero_() 141 | 142 | for i in range(len(batch)): 143 | text = batch[i][0] 144 | mel = batch[i][1] 145 | spc = batch[i][2] 146 | 147 | text_input_padded[i,:text.size(0)] = text 148 | mel_padded[i, :, :mel.size(1)] = mel 149 | spc_padded[i, :, :spc.size(1)] = spc 150 | speaker_id[i] = batch[i][3][0] 151 | #make sure the downsampled stop_token_padded have the last eng flag 1. 
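# For example, with n_frames_per_step=2, a 6-frame mel in a batch padded to 8 frames
# sets stop_token_padded[i, 4:] = 1 (frames 4-7); after ParrotLoss downsamples the stop
# targets by n_frames_per_step, the last real decoder step (frames 4-5) is flagged 1
# while all earlier steps stay 0.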
152 | stop_token_padded[i, mel.size(1)-self.n_frames_per_step:] = 1 153 | 154 | 155 | return text_input_padded, mel_padded, spc_padded, speaker_id, \ 156 | text_lengths, mel_lengths, stop_token_padded 157 | -------------------------------------------------------------------------------- /fine-tune/reader/symbols.py: -------------------------------------------------------------------------------- 1 | phone_list = ['pau', 'iy', 'aa', 'ch', 'ae', 'eh', 2 | 'ah', 'ao', 'ih', 'ey', 'aw', 3 | 'ay', 'ax', 'er', 'ng', 4 | 'sh', 'th', 'uh', 'zh', 'oy', 5 | 'dh', 'y', 'hh', 'jh', 'b', 6 | 'd', 'g', 'f', 'k', 'm', 7 | 'l', 'n', 'p', 's', 'r', 8 | 't', 'w', 'v', 'ow', 'z', 9 | 'uw', 'SOS/EOS'] 10 | 11 | seen_speakers = ['p336', 'p240', 'p262', 'p333', 'p297', 'p339', 'p276', 'p269', 'p303', 'p260', 'p250', 'p345', 'p305', 'p283', 'p277', 'p302', 'p280', 'p295', 'p245', 'p227', 'p257', 'p282', 'p259', 'p311', 'p301', 'p265', 'p270', 'p329', 'p362', 'p343', 'p246', 'p247', 'p351', 'p263', 'p363', 'p249', 'p231', 'p292', 'p304', 'p347', 'p314', 'p244', 'p261', 'p298', 'p272', 'p308', 'p299', 'p234', 'p268', 'p271', 'p316', 'p287', 'p318', 'p264', 'p313', 'p236', 'p238', 'p334', 'p312', 'p230', 'p253', 'p323', 'p361', 'p275', 'p252', 'p374', 'p286', 'p274', 'p254', 'p310', 'p306', 'p294', 'p326', 'p225', 'p255', 'p293', 'p278', 'p266', 'p229', 'p335', 'p281', 'p307', 'p256', 'p243', 'p364', 'p239', 'p232', 'p258', 'p267', 'p317', 'p284', 'p300', 'p288', 'p341', 'p340', 'p279', 'p330', 'p360', 'p285'] 12 | 13 | ph2id = {ph:i for i, ph in enumerate(phone_list)} 14 | ph2id['ssil'] = ph2id['pau'] 15 | sp2id = {sp:i for i, sp in enumerate(seen_speakers)} 16 | id2ph = {i:ph for i, ph in enumerate(phone_list)} 17 | id2sp = {i:sp for i, sp in enumerate(seen_speakers)} 18 | -------------------------------------------------------------------------------- /fine-tune/run.sh: -------------------------------------------------------------------------------- 1 | # slt to rms, rms to slt # 2 | RUN_EMB=true 3 | RUN_TRAIN=true 4 | RUN_GEN=true 5 | 6 | export CUDA_VISIBLE_DEVICES=2 7 | speaker_A='slt' 8 | speaker_B='rms' 9 | training_list="/home/jxzhang/Documents/DataSets/cmu_us_slt_arctic-0.95-release/list/train_non-parallel_${speaker_A}_${speaker_B}.list" 10 | validation_list="/home/jxzhang/Documents/DataSets/cmu_us_slt_arctic-0.95-release/list/eval_${speaker_A}_${speaker_B}.list" 11 | 12 | logdir="logdir_${speaker_A}_${speaker_B}" 13 | pretrain_checkpoint_path='../pre-train/outdir/checkpoint_0' 14 | finetune_ckpt="checkpoint_100" 15 | 16 | contrastive_loss_w=30.0 17 | speaker_adversial_loss_w=0.2 18 | speaker_classifier_loss_w=1.0 19 | decay_every=7 20 | warmup=7 21 | epochs=70 22 | batch_size=8 23 | SC_kernel_size=1 24 | learning_rate=1e-3 25 | gen_num=66 26 | 27 | if $RUN_EMB 28 | then 29 | echo 'running embeddings...' 30 | python inference_embedding.py \ 31 | -c $pretrain_checkpoint_path \ 32 | --hparams=speaker_A=$speaker_A,\ 33 | speaker_B=$speaker_B,\ 34 | training_list=${training_list},SC_kernel_size=$SC_kernel_size 35 | fi 36 | 37 | if $RUN_TRAIN 38 | then 39 | echo 'running trainings...' 
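# The backslash-newline continuations below keep the --hparams value a single
# comma-separated key=value string, which train.py presumably forwards to
# hparams.parse() via create_hparams(). a_embedding_path and b_embedding_path
# point at the speaker embeddings the RUN_EMB step above is expected to have
# written (an assumption about where inference_embedding.py saves them).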
40 | python train.py \ 41 | -l $logdir -o outdir --n_gpus=1 \ 42 | -c $pretrain_checkpoint_path \ 43 | --warm_start \ 44 | --hparams=speaker_A=$speaker_A,\ 45 | speaker_B=$speaker_B,a_embedding_path="outdir/embeddings/${speaker_A}.npy",\ 46 | b_embedding_path="outdir/embeddings/${speaker_B}.npy",\ 47 | training_list=$training_list,\ 48 | validation_list=$validation_list,\ 49 | contrastive_loss_w=$contrastive_loss_w,\ 50 | speaker_adversial_loss_w=$speaker_adversial_loss_w,\ 51 | speaker_classifier_loss_w=$speaker_classifier_loss_w,\ 52 | decay_every=$decay_every,\ 53 | epochs=$epochs,\ 54 | warmup=$warmup,batch_size=$batch_size,\ 55 | SC_kernel_size=$SC_kernel_size,learning_rate=$learning_rate 56 | fi 57 | 58 | 59 | if $RUN_GEN 60 | then 61 | echo 'running generations...' 62 | python inference.py \ 63 | -c outdir/$logdir/$finetune_ckpt \ 64 | --num $gen_num \ 65 | --hparams=validation_list=$validation_list,SC_kernel_size=$SC_kernel_size 66 | fi 67 | -------------------------------------------------------------------------------- /fine-tune/zero_embeddings.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxzhanggg/nonparaSeq2seqVC_code/4c03a6be3bc76207b7cf8222c985dc85c7018cde/fine-tune/zero_embeddings.npy -------------------------------------------------------------------------------- /pre-train/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.modules import Module 4 | from torch.autograd import Variable 5 | 6 | def _flatten_dense_tensors(tensors): 7 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 8 | same dense type. 9 | Since inputs are dense, the resulting tensor will be a concatenated 1D 10 | buffer. Element-wise operation on this buffer will be equivalent to 11 | operating individually. 12 | Arguments: 13 | tensors (Iterable[Tensor]): dense tensors to flatten. 14 | Returns: 15 | A contiguous 1D buffer containing input tensors. 16 | """ 17 | if len(tensors) == 1: 18 | return tensors[0].contiguous().view(-1) 19 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 20 | return flat 21 | 22 | def _unflatten_dense_tensors(flat, tensors): 23 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 24 | same dense type, and that flat is given by _flatten_dense_tensors. 25 | Arguments: 26 | flat (Tensor): flattened dense tensors to unflatten. 27 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 28 | unflatten flat. 29 | Returns: 30 | Unflattened dense tensors with sizes same as tensors and values from 31 | flat. 32 | """ 33 | outputs = [] 34 | offset = 0 35 | for tensor in tensors: 36 | numel = tensor.numel() 37 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 38 | offset += numel 39 | return tuple(outputs) 40 | 41 | 42 | ''' 43 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py 44 | launcher included with this example. It assumes that your run is using multiprocess with 1 45 | GPU/process, that the model is on the correct device, and that torch.set_device has been 46 | used to set the device. 47 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 48 | and will be allreduced at the finish of the backward pass. 
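A typical launch with the bundled launcher would be, for example, python multiproc.py train.py -l logdir -o outdir --hparams=distributed_run=True (the exact train.py flags here are an assumption based on run.sh); multiproc.py then spawns one copy of train.py per visible GPU and appends the --n_gpus, --rank and --group_name arguments for each worker.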
49 | ''' 50 | class DistributedDataParallel(Module): 51 | 52 | def __init__(self, module): 53 | super(DistributedDataParallel, self).__init__() 54 | #fallback for PyTorch 0.3 55 | if not hasattr(dist, '_backend'): 56 | self.warn_on_half = True 57 | else: 58 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 59 | 60 | self.module = module 61 | 62 | for p in list(self.module.state_dict().values()): 63 | if not torch.is_tensor(p): 64 | continue 65 | dist.broadcast(p, 0) 66 | 67 | def allreduce_params(): 68 | if(self.needs_reduction): 69 | self.needs_reduction = False 70 | buckets = {} 71 | for param in self.module.parameters(): 72 | if param.requires_grad and param.grad is not None: 73 | tp = type(param.data) 74 | if tp not in buckets: 75 | buckets[tp] = [] 76 | buckets[tp].append(param) 77 | if self.warn_on_half: 78 | if torch.cuda.HalfTensor in buckets: 79 | print(("WARNING: gloo dist backend for half parameters may be extremely slow." + 80 | " It is recommended to use the NCCL backend in this case. This currently requires" + 81 | "PyTorch built from top of tree master.")) 82 | self.warn_on_half = False 83 | 84 | for tp in buckets: 85 | bucket = buckets[tp] 86 | grads = [param.grad.data for param in bucket] 87 | coalesced = _flatten_dense_tensors(grads) 88 | dist.all_reduce(coalesced) 89 | coalesced /= dist.get_world_size() 90 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 91 | buf.copy_(synced) 92 | 93 | for param in list(self.module.parameters()): 94 | def allreduce_hook(*unused): 95 | param._execution_engine.queue_callback(allreduce_params) 96 | if param.requires_grad: 97 | param.register_hook(allreduce_hook) 98 | 99 | def forward(self, *inputs, **kwargs): 100 | self.needs_reduction = True 101 | return self.module(*inputs, **kwargs) 102 | 103 | ''' 104 | def _sync_buffers(self): 105 | buffers = list(self.module._all_buffers()) 106 | if len(buffers) > 0: 107 | # cross-node buffer sync 108 | flat_buffers = _flatten_dense_tensors(buffers) 109 | dist.broadcast(flat_buffers, 0) 110 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 111 | buf.copy_(synced) 112 | def train(self, mode=True): 113 | # Clear NCCL communicator and CUDA event cache of the default group ID, 114 | # These cache will be recreated at the later call. This is currently a 115 | # work-around for a potential NCCL deadlock. 
116 | if dist._backend == dist.dist_backend.NCCL: 117 | dist._clear_group_cache() 118 | super(DistributedDataParallel, self).train(mode) 119 | self.module.train(mode) 120 | ''' 121 | ''' 122 | Modifies existing model to do gradient allreduce, but doesn't change class 123 | so you don't need "module" 124 | ''' 125 | def apply_gradient_allreduce(module): 126 | if not hasattr(dist, '_backend'): 127 | module.warn_on_half = True 128 | else: 129 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 130 | 131 | for p in list(module.state_dict().values()): 132 | if not torch.is_tensor(p): 133 | continue 134 | dist.broadcast(p, 0) 135 | 136 | def allreduce_params(): 137 | if(module.needs_reduction): 138 | module.needs_reduction = False 139 | buckets = {} 140 | for param in module.parameters(): 141 | if param.requires_grad and param.grad is not None: 142 | tp = type(param.data) 143 | if tp not in buckets: 144 | buckets[tp] = [] 145 | buckets[tp].append(param) 146 | if module.warn_on_half: 147 | if torch.cuda.HalfTensor in buckets: 148 | print(("WARNING: gloo dist backend for half parameters may be extremely slow." + 149 | " It is recommended to use the NCCL backend in this case. This currently requires" + 150 | "PyTorch built from top of tree master.")) 151 | module.warn_on_half = False 152 | 153 | for tp in buckets: 154 | bucket = buckets[tp] 155 | grads = [param.grad.data for param in bucket] 156 | coalesced = _flatten_dense_tensors(grads) 157 | dist.all_reduce(coalesced) 158 | coalesced /= dist.get_world_size() 159 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 160 | buf.copy_(synced) 161 | 162 | for param in list(module.parameters()): 163 | def allreduce_hook(*unused): 164 | Variable._execution_engine.queue_callback(allreduce_params) 165 | if param.requires_grad: 166 | param.register_hook(allreduce_hook) 167 | 168 | def set_needs_reduction(self, input, output): 169 | self.needs_reduction = True 170 | 171 | module.register_forward_hook(set_needs_reduction) 172 | return module -------------------------------------------------------------------------------- /pre-train/hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | #from text import symbols 3 | 4 | def create_hparams(hparams_string=None, verbose=False): 5 | """Create model hyperparameters. 
Parse nondefault from given string.""" 6 | 7 | hparams = tf.contrib.training.HParams( 8 | ################################ 9 | # Experiment Parameters # 10 | ################################ 11 | epochs=200, 12 | iters_per_checkpoint=1000, 13 | seed=1234, 14 | distributed_run=False, 15 | dist_backend="nccl", 16 | dist_url="tcp://localhost:54321", 17 | cudnn_enabled=True, 18 | cudnn_benchmark=False, 19 | 20 | ################################ 21 | # Data Parameters # 22 | ################################ 23 | training_list='/home/jxzhang/Documents/DataSets/VCTK/list/train_english_extend_no_indian.list', 24 | validation_list='/home/jxzhang/Documents/DataSets/VCTK/list/eval_english_extend_no_indian.list', 25 | mel_mean_std='/home/jxzhang/Documents/DataSets/VCTK/mel_mean_std.npy', 26 | 27 | ################################ 28 | # Data Parameters # 29 | ################################ 30 | n_mel_channels=80, 31 | n_spc_channels=1025, 32 | n_symbols=41, # 33 | n_speakers=99, # 34 | predict_spectrogram=False, 35 | 36 | ################################ 37 | # Model Parameters # 38 | ################################ 39 | 40 | symbols_embedding_dim=512, 41 | 42 | # Text Encoder parameters 43 | encoder_kernel_size=5, 44 | encoder_n_convolutions=3, 45 | encoder_embedding_dim=512, 46 | text_encoder_dropout=0.5, 47 | 48 | # Audio Encoder parameters 49 | spemb_input=False, 50 | n_frames_per_step_encoder=2, 51 | audio_encoder_hidden_dim=512, 52 | AE_attention_dim=128, 53 | AE_attention_location_n_filters=32, 54 | AE_attention_location_kernel_size=51, 55 | beam_width=10, 56 | 57 | # hidden activation 58 | # relu linear tanh 59 | hidden_activation='tanh', 60 | 61 | #Speaker Encoder parameters 62 | speaker_encoder_hidden_dim=256, 63 | speaker_encoder_dropout=0.2, 64 | speaker_embedding_dim=128, 65 | 66 | 67 | #Speaker Classifier parameters 68 | SC_hidden_dim=512, 69 | SC_n_convolutions=3, 70 | SC_kernel_size=1, 71 | 72 | # Decoder parameters 73 | feed_back_last=True, 74 | n_frames_per_step_decoder=2, 75 | decoder_rnn_dim=512, 76 | prenet_dim=[256,256], 77 | max_decoder_steps=1000, 78 | stop_threshold=0.5, 79 | 80 | # Attention parameters 81 | attention_rnn_dim=512, 82 | attention_dim=128, 83 | 84 | # Location Layer parameters 85 | attention_location_n_filters=32, 86 | attention_location_kernel_size=17, 87 | 88 | # PostNet parameters 89 | postnet_n_convolutions=5, 90 | postnet_dim=512, 91 | postnet_kernel_size=5, 92 | postnet_dropout=0.5, 93 | 94 | ################################ 95 | # Optimization Hyperparameters # 96 | ################################ 97 | use_saved_learning_rate=False, 98 | learning_rate=1e-3, 99 | weight_decay=1e-6, 100 | grad_clip_thresh=5.0, 101 | batch_size=32, 102 | 103 | contrastive_loss_w=30.0, 104 | speaker_encoder_loss_w=1.0, 105 | text_classifier_loss_w=1.0, 106 | speaker_adversial_loss_w=20., 107 | speaker_classifier_loss_w=0.1, 108 | ce_loss=False 109 | ) 110 | 111 | if hparams_string: 112 | tf.logging.info('Parsing command line hparams: %s', hparams_string) 113 | hparams.parse(hparams_string) 114 | 115 | if verbose: 116 | tf.logging.info('Final parsed hparams: %s', list(hparams.values())) 117 | 118 | return hparams 119 | -------------------------------------------------------------------------------- /pre-train/inference.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | 5 | 6 | import os 7 | import librosa 8 | import numpy as np 9 | import torch 10 | from 
torch.utils.data import DataLoader 11 | 12 | from reader import TextMelIDLoader, TextMelIDCollate, id2ph, id2sp 13 | from hparams import create_hparams 14 | from model import Parrot, lcm 15 | from train import load_model 16 | import scipy.io.wavfile 17 | 18 | 19 | ########### Configuration ########### 20 | hparams = create_hparams() 21 | 22 | #generation list 23 | hlist = '/home/jxzhang/Documents/DataSets/VCTK/list/hold_english.list' 24 | tlist = '/home/jxzhang/Documents/DataSets/VCTK/list/eval_english.list' 25 | 26 | # use seen (tlist) or unseen list (hlist) 27 | test_list = tlist 28 | checkpoint_path='outdir/checkpoint_0' 29 | # TTS or VC task? 30 | input_text=False 31 | # number of utterances for generation 32 | NUM=10 33 | ISMEL=(not hparams.predict_spectrogram) 34 | ##################################### 35 | 36 | def plot_data(data, fn, figsize=(12, 4)): 37 | fig, axes = plt.subplots(1, len(data), figsize=figsize) 38 | for i in range(len(data)): 39 | if len(data) == 1: 40 | ax = axes 41 | else: 42 | ax = axes[i] 43 | g = ax.imshow(data[i], aspect='auto', origin='bottom', 44 | interpolation='none') 45 | plt.colorbar(g, ax=ax) 46 | plt.savefig(fn) 47 | 48 | 49 | model = load_model(hparams) 50 | 51 | model.load_state_dict(torch.load(checkpoint_path)['state_dict']) 52 | _ = model.eval() 53 | 54 | test_set = TextMelIDLoader(test_list, hparams.mel_mean_std, shuffle=True) 55 | sample_list = test_set.file_path_list 56 | collate_fn = TextMelIDCollate(lcm(hparams.n_frames_per_step_encoder, 57 | hparams.n_frames_per_step_decoder)) 58 | 59 | test_loader = DataLoader(test_set, num_workers=1, shuffle=False, 60 | sampler=None, 61 | batch_size=1, pin_memory=False, 62 | drop_last=True, collate_fn=collate_fn) 63 | 64 | 65 | 66 | task = 'tts' if input_text else 'vc' 67 | path_save = os.path.join(checkpoint_path.replace('checkpoint', 'test'), task) 68 | path_save += '_seen' if test_list == tlist else '_unseen' 69 | if not os.path.exists(path_save): 70 | os.makedirs(path_save) 71 | 72 | print(path_save) 73 | 74 | def recover_wav(mel, wav_path, ismel=False, 75 | n_fft=2048, win_length=800,hop_length=200): 76 | 77 | if ismel: 78 | mean, std = np.load(hparams.mel_mean_std) 79 | else: 80 | mean, std = np.load(hparams.mel_mean_std.replace('mel','spec')) 81 | 82 | mean = mean[:,None] 83 | std = std[:,None] 84 | mel = 1.2 * mel * std + mean 85 | mel = np.exp(mel) 86 | 87 | if ismel: 88 | filters = librosa.filters.mel(sr=16000, n_fft=2048, n_mels=80) 89 | inv_filters = np.linalg.pinv(filters) 90 | spec = np.dot(inv_filters, mel) 91 | else: 92 | spec = mel 93 | 94 | def _griffin_lim(stftm_matrix, shape, max_iter=50): 95 | y = np.random.random(shape) 96 | for i in range(max_iter): 97 | stft_matrix = librosa.core.stft(y, n_fft=n_fft, win_length=win_length, hop_length=hop_length) 98 | stft_matrix = stftm_matrix * stft_matrix / np.abs(stft_matrix) 99 | y = librosa.core.istft(stft_matrix, win_length=win_length, hop_length=hop_length) 100 | return y 101 | 102 | shape = spec.shape[1] * hop_length - hop_length + 1 103 | 104 | y = _griffin_lim(spec, shape) 105 | scipy.io.wavfile.write(wav_path, 16000, y) 106 | return y 107 | 108 | 109 | text_input, mel, spec, speaker_id = test_set[0] 110 | reference_mel = mel.cuda().unsqueeze(0) 111 | ref_sp = id2sp[speaker_id.item()] 112 | 113 | def levenshteinDistance(s1, s2): 114 | if len(s1) > len(s2): 115 | s1, s2 = s2, s1 116 | 117 | distances = list(range(len(s1) + 1)) 118 | for i2, c2 in enumerate(s2): 119 | distances_ = [i2+1] 120 | for i1, c1 in enumerate(s1): 121 | if c1 == c2: 
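# matching symbols: carry over the diagonal cost, no edit is charged for this pair
# (the else branch below charges 1 + the minimum of substitution, deletion and insertion)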
122 | distances_.append(distances[i1]) 123 | else: 124 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 125 | distances = distances_ 126 | return distances[-1] 127 | 128 | with torch.no_grad(): 129 | 130 | errs = 0 131 | totalphs = 0 132 | 133 | for i, batch in enumerate(test_loader): 134 | if i == NUM: 135 | break 136 | 137 | sample_id = sample_list[i].split('/')[-1][9:17] 138 | print(('%d index %s, decoding ...'%(i,sample_id))) 139 | 140 | x, y = model.parse_batch(batch) 141 | predicted_mel, post_output, predicted_stop, alignments, \ 142 | text_hidden, audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments, \ 143 | speaker_id = model.inference(x, input_text, reference_mel, hparams.beam_width) 144 | 145 | post_output = post_output.data.cpu().numpy()[0] 146 | alignments = alignments.data.cpu().numpy()[0].T 147 | audio_seq2seq_alignments = audio_seq2seq_alignments.data.cpu().numpy()[0].T 148 | 149 | text_hidden = text_hidden.data.cpu().numpy()[0].T #-> [hidden_dim, max_text_len] 150 | audio_seq2seq_hidden = audio_seq2seq_hidden.data.cpu().numpy()[0].T 151 | audio_seq2seq_phids = audio_seq2seq_phids.data.cpu().numpy()[0] # [T + 1] 152 | speaker_id = speaker_id.data.cpu().numpy()[0] # scalar 153 | 154 | task = 'TTS' if input_text else 'VC' 155 | 156 | recover_wav(post_output, 157 | os.path.join(path_save, 'Wav_%s_ref_%s_%s.wav'%(sample_id, ref_sp, task)), 158 | ismel=ISMEL) 159 | 160 | post_output_path = os.path.join(path_save, 'Mel_%s_ref_%s_%s.npy'%(sample_id, ref_sp, task)) 161 | np.save(post_output_path, post_output) 162 | 163 | plot_data([alignments, audio_seq2seq_alignments], 164 | os.path.join(path_save, 'Ali_%s_ref_%s_%s.pdf'%(sample_id, ref_sp, task))) 165 | 166 | plot_data([np.hstack([text_hidden, audio_seq2seq_hidden])], 167 | os.path.join(path_save, 'Hid_%s_ref_%s_%s.pdf'%(sample_id, ref_sp, task))) 168 | 169 | audio_seq2seq_phids = [id2ph[id] for id in audio_seq2seq_phids[:-1]] 170 | target_text = y[0].data.cpu().numpy()[0] 171 | target_text = [id2ph[id] for id in target_text[:]] 172 | 173 | print('Sounds like %s, Decoded text is '%(id2sp[speaker_id])) 174 | 175 | print(audio_seq2seq_phids) 176 | print(target_text) 177 | 178 | err = levenshteinDistance(audio_seq2seq_phids, target_text) 179 | print(err, len(target_text)) 180 | 181 | errs += err 182 | totalphs += len(target_text) 183 | 184 | print(float(errs)/float(totalphs)) 185 | 186 | 187 | -------------------------------------------------------------------------------- /pre-train/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch.nn.functional as F 4 | from tensorboardX import SummaryWriter 5 | from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy, plot_alignment 6 | from plotting_utils import plot_gate_outputs_to_numpy 7 | 8 | 9 | class ParrotLogger(SummaryWriter): 10 | def __init__(self, logdir, ali_path='ali'): 11 | super(ParrotLogger, self).__init__(logdir) 12 | ali_path = os.path.join(logdir, ali_path) 13 | if not os.path.exists(ali_path): 14 | os.makedirs(ali_path) 15 | self.ali_path = ali_path 16 | 17 | def log_training(self, reduced_loss, reduced_losses, reduced_acces, grad_norm, learning_rate, duration, 18 | iteration): 19 | 20 | self.add_scalar("training.loss", reduced_loss, iteration) 21 | self.add_scalar("training.loss.recon", reduced_losses[0], iteration) 22 | self.add_scalar("training.loss.recon_post", reduced_losses[1], iteration) 23 | 
self.add_scalar("training.loss.stop", reduced_losses[2], iteration) 24 | self.add_scalar("training.loss.contr", reduced_losses[3], iteration) 25 | self.add_scalar("training.loss.spenc", reduced_losses[4], iteration) 26 | self.add_scalar("training.loss.spcla", reduced_losses[5], iteration) 27 | self.add_scalar("training.loss.texcl", reduced_losses[6], iteration) 28 | self.add_scalar("training.loss.spadv", reduced_losses[7], iteration) 29 | 30 | self.add_scalar("grad.norm", grad_norm, iteration) 31 | self.add_scalar("learning.rate", learning_rate, iteration) 32 | self.add_scalar("duration", duration, iteration) 33 | 34 | 35 | self.add_scalar('training.acc.spenc', reduced_acces[0], iteration) 36 | self.add_scalar('training.acc.spcla', reduced_acces[1], iteration) 37 | self.add_scalar('training.acc.texcl', reduced_acces[2], iteration) 38 | 39 | def log_validation(self, reduced_loss, reduced_losses, reduced_acces, model, y, y_pred, iteration, task): 40 | 41 | self.add_scalar('validation.loss.%s'%task, reduced_loss, iteration) 42 | self.add_scalar("validation.loss.%s.recon"%task, reduced_losses[0], iteration) 43 | self.add_scalar("validation.loss.%s.recon_post"%task, reduced_losses[1], iteration) 44 | self.add_scalar("validation.loss.%s.stop"%task, reduced_losses[2], iteration) 45 | self.add_scalar("validation.loss.%s.contr"%task, reduced_losses[3], iteration) 46 | self.add_scalar("validation.loss.%s.spenc"%task, reduced_losses[4], iteration) 47 | self.add_scalar("validation.loss.%s.spcla"%task, reduced_losses[5], iteration) 48 | self.add_scalar("validation.loss.%s.texcl"%task, reduced_losses[6], iteration) 49 | self.add_scalar("validation.loss.%s.spadv"%task, reduced_losses[7], iteration) 50 | 51 | self.add_scalar('validation.acc.%s.spenc'%task, reduced_acces[0], iteration) 52 | self.add_scalar('validation.acc.%s.spcla'%task, reduced_acces[1], iteration) 53 | self.add_scalar('validatoin.acc.%s.texcl'%task, reduced_acces[2], iteration) 54 | 55 | predicted_mel, post_output, predicted_stop, alignments, \ 56 | text_hidden, mel_hidden, text_logit_from_mel_hidden, \ 57 | audio_seq2seq_alignments, \ 58 | speaker_logit_from_mel, speaker_logit_from_mel_hidden, \ 59 | text_lengths, mel_lengths = y_pred 60 | 61 | text_target, mel_target, spc_target, speaker_target, stop_target = y 62 | 63 | stop_target = stop_target.reshape(stop_target.size(0), -1, int(stop_target.size(1)/predicted_stop.size(1))) 64 | stop_target = stop_target[:,:,0] 65 | 66 | # plot distribution of parameters 67 | #for tag, value in model.named_parameters(): 68 | # tag = tag.replace('.', '/') 69 | # self.add_histogram(tag, value.data.cpu().numpy(), iteration) 70 | 71 | # plot alignment, mel target and predicted, stop target and predicted 72 | idx = random.randint(0, alignments.size(0) - 1) 73 | 74 | alignments = alignments.data.cpu().numpy() 75 | audio_seq2seq_alignments = audio_seq2seq_alignments.data.cpu().numpy() 76 | 77 | self.add_image( 78 | "%s.alignment"%task, 79 | plot_alignment_to_numpy(alignments[idx].T), 80 | iteration, dataformats='HWC') 81 | 82 | # plot more alignments 83 | plot_alignment(alignments[:4], self.ali_path+'/step-%d-%s.pdf'%(iteration, task)) 84 | 85 | self.add_image( 86 | "%s.audio_seq2seq_alignment"%task, 87 | plot_alignment_to_numpy(audio_seq2seq_alignments[idx].T), 88 | iteration, dataformats='HWC') 89 | 90 | self.add_image( 91 | "%s.mel_target"%task, 92 | plot_spectrogram_to_numpy(mel_target[idx].data.cpu().numpy()), 93 | iteration, dataformats='HWC') 94 | 95 | self.add_image( 96 | "%s.mel_predicted"%task, 
97 | plot_spectrogram_to_numpy(predicted_mel[idx].data.cpu().numpy()), 98 | iteration, dataformats='HWC') 99 | 100 | self.add_image( 101 | "%s.spc_target"%task, 102 | plot_spectrogram_to_numpy(spc_target[idx].data.cpu().numpy()), 103 | iteration, dataformats='HWC') 104 | 105 | self.add_image( 106 | "%s.post_predicted"%task, 107 | plot_spectrogram_to_numpy(post_output[idx].data.cpu().numpy()), 108 | iteration, dataformats='HWC') 109 | 110 | self.add_image( 111 | "%s.stop"%task, 112 | plot_gate_outputs_to_numpy( 113 | stop_target[idx].data.cpu().numpy(), 114 | F.sigmoid(predicted_stop[idx]).data.cpu().numpy()), 115 | iteration, dataformats='HWC') 116 | -------------------------------------------------------------------------------- /pre-train/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Parrot 2 | from .loss import ParrotLoss 3 | from .utils import lcm,gcd -------------------------------------------------------------------------------- /pre-train/model/basic_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def tile(x, count, dim=0): 7 | """ 8 | Tiles x on dimension dim count times. 9 | """ 10 | perm = list(range(len(x.size()))) 11 | if dim != 0: 12 | perm[0], perm[dim] = perm[dim], perm[0] 13 | x = x.permute(perm).contiguous() 14 | out_size = list(x.size()) 15 | out_size[0] *= count 16 | batch = x.size(0) 17 | x = x.view(batch, -1) \ 18 | .transpose(0, 1) \ 19 | .repeat(count, 1) \ 20 | .transpose(0, 1) \ 21 | .contiguous() \ 22 | .view(*out_size) 23 | if dim != 0: 24 | x = x.permute(perm).contiguous() 25 | return x 26 | 27 | 28 | def sort_batch(data, lengths): 29 | ''' 30 | sort data by length 31 | sorted_data[initial_index] == data 32 | ''' 33 | sorted_lengths, sorted_index = lengths.sort(0, descending=True) 34 | sorted_data = data[sorted_index] 35 | _, initial_index = sorted_index.sort(0, descending=False) 36 | 37 | return sorted_data, sorted_lengths, initial_index 38 | 39 | 40 | class LinearNorm(torch.nn.Module): 41 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 42 | super(LinearNorm, self).__init__() 43 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 44 | 45 | torch.nn.init.xavier_uniform_( 46 | self.linear_layer.weight, 47 | gain=torch.nn.init.calculate_gain(w_init_gain)) 48 | 49 | def forward(self, x): 50 | return self.linear_layer(x) 51 | 52 | 53 | class ConvNorm(torch.nn.Module): 54 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 55 | padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): 56 | super(ConvNorm, self).__init__() 57 | if padding is None: 58 | assert(kernel_size % 2 == 1) 59 | padding = int(dilation * (kernel_size - 1) / 2) 60 | 61 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 62 | kernel_size=kernel_size, stride=stride, 63 | padding=padding, dilation=dilation, 64 | bias=bias) 65 | 66 | torch.nn.init.xavier_uniform_( 67 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) 68 | 69 | def forward(self, signal): 70 | conv_signal = self.conv(signal) 71 | return conv_signal 72 | 73 | 74 | class Prenet(nn.Module): 75 | def __init__(self, in_dim, sizes): 76 | super(Prenet, self).__init__() 77 | in_sizes = [in_dim] + sizes[:-1] 78 | self.layers = nn.ModuleList( 79 | [LinearNorm(in_size, out_size, bias=False) 80 | for (in_size, out_size) in 
zip(in_sizes, sizes)]) 81 | 82 | def forward(self, x): 83 | for linear in self.layers: 84 | x = F.dropout(F.relu(linear(x)), p=0.5, training=True) 85 | return x 86 | 87 | 88 | class LocationLayer(nn.Module): 89 | def __init__(self, attention_n_filters, attention_kernel_size, 90 | attention_dim): 91 | super(LocationLayer, self).__init__() 92 | padding = int((attention_kernel_size - 1) / 2) 93 | self.location_conv = ConvNorm(2, attention_n_filters, 94 | kernel_size=attention_kernel_size, 95 | padding=padding, bias=False, stride=1, 96 | dilation=1) 97 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 98 | bias=False, w_init_gain='tanh') 99 | 100 | def forward(self, attention_weights_cat): 101 | processed_attention = self.location_conv(attention_weights_cat) 102 | processed_attention = processed_attention.transpose(1, 2) 103 | processed_attention = self.location_dense(processed_attention) 104 | return processed_attention 105 | 106 | 107 | class Attention(nn.Module): 108 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 109 | attention_location_n_filters, attention_location_kernel_size): 110 | super(Attention, self).__init__() 111 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 112 | bias=False, w_init_gain='tanh') 113 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 114 | w_init_gain='tanh') 115 | self.v = LinearNorm(attention_dim, 1, bias=False) 116 | self.location_layer = LocationLayer(attention_location_n_filters, 117 | attention_location_kernel_size, 118 | attention_dim) 119 | self.score_mask_value = -float("inf") 120 | 121 | def get_alignment_energies(self, query, processed_memory, 122 | attention_weights_cat): 123 | """ 124 | PARAMS 125 | ------ 126 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 127 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 128 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time) 129 | RETURNS 130 | ------- 131 | alignment (batch, max_time) 132 | """ 133 | 134 | processed_query = self.query_layer(query.unsqueeze(1)) 135 | processed_attention_weights = self.location_layer(attention_weights_cat) 136 | energies = self.v(torch.tanh( 137 | processed_query + processed_attention_weights + processed_memory)) 138 | 139 | energies = energies.squeeze(-1) 140 | return energies 141 | 142 | def forward(self, attention_hidden_state, memory, processed_memory, 143 | attention_weights_cat, mask): 144 | """ 145 | PARAMS 146 | ------ 147 | attention_hidden_state: attention rnn last output 148 | memory: encoder outputs 149 | processed_memory: processed encoder outputs 150 | attention_weights_cat: previous and cummulative attention weights 151 | mask: binary mask for padded data 152 | """ 153 | alignment = self.get_alignment_energies( 154 | attention_hidden_state, processed_memory, attention_weights_cat) 155 | 156 | if mask is not None: 157 | alignment.data.masked_fill_(mask, self.score_mask_value) 158 | 159 | attention_weights = F.softmax(alignment, dim=1) 160 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 161 | attention_context = attention_context.squeeze(1) 162 | 163 | return attention_context, attention_weights 164 | 165 | 166 | class ForwardAttentionV2(nn.Module): 167 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 168 | attention_location_n_filters, attention_location_kernel_size): 169 | super(ForwardAttentionV2, self).__init__() 170 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 171 | bias=False, w_init_gain='tanh') 172 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 173 | w_init_gain='tanh') 174 | self.v = LinearNorm(attention_dim, 1, bias=False) 175 | self.location_layer = LocationLayer(attention_location_n_filters, 176 | attention_location_kernel_size, 177 | attention_dim) 178 | self.score_mask_value = -float(1e20) 179 | 180 | def get_alignment_energies(self, query, processed_memory, 181 | attention_weights_cat): 182 | """ 183 | PARAMS 184 | ------ 185 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 186 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 187 | attention_weights_cat: prev. 
and cumulative att weights (B, 2, max_time) 188 | RETURNS 189 | ------- 190 | alignment (batch, max_time) 191 | """ 192 | 193 | processed_query = self.query_layer(query.unsqueeze(1)) 194 | processed_attention_weights = self.location_layer(attention_weights_cat) 195 | energies = self.v(torch.tanh( 196 | processed_query + processed_attention_weights + processed_memory)) 197 | 198 | energies = energies.squeeze(-1) 199 | return energies 200 | 201 | def forward(self, attention_hidden_state, memory, processed_memory, 202 | attention_weights_cat, mask, log_alpha): 203 | """ 204 | PARAMS 205 | ------ 206 | attention_hidden_state: attention rnn last output 207 | memory: encoder outputs 208 | processed_memory: processed encoder outputs 209 | attention_weights_cat: previous and cummulative attention weights 210 | mask: binary mask for padded data 211 | """ 212 | log_energy = self.get_alignment_energies( 213 | attention_hidden_state, processed_memory, attention_weights_cat) 214 | 215 | #log_energy = 216 | 217 | if mask is not None: 218 | log_energy.data.masked_fill_(mask, self.score_mask_value) 219 | 220 | #attention_weights = F.softmax(alignment, dim=1) 221 | 222 | #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] 223 | #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] 224 | 225 | #log_total_score = log_alpha + content_score 226 | 227 | #previous_attention_weights = attention_weights_cat[:,0,:] 228 | 229 | log_alpha_shift_padded = [] 230 | max_time = log_energy.size(1) 231 | for sft in range(2): 232 | shifted = log_alpha[:,:max_time-sft] 233 | shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value) 234 | log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) 235 | 236 | biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2) 237 | 238 | log_alpha_new = biased + log_energy 239 | 240 | attention_weights = F.softmax(log_alpha_new, dim=1) 241 | 242 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 243 | attention_context = attention_context.squeeze(1) 244 | 245 | return attention_context, attention_weights, log_alpha_new -------------------------------------------------------------------------------- /pre-train/model/beam.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from .penalties import PenaltyBuilder 5 | 6 | 7 | 8 | class Beam(object): 9 | """ 10 | ''' 11 | adapt from opennmt 12 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam.py 13 | ''' 14 | 15 | Class for managing the internals of the beam search process. 16 | Takes care of beams, back pointers, and scores. 17 | Args: 18 | size (int): beam size 19 | pad, bos, eos (int): indices of padding, beginning, and ending. 20 | n_best (int): nbest size to use 21 | cuda (bool): use gpu 22 | global_scorer (:obj:`GlobalScorer`) 23 | """ 24 | 25 | def __init__(self, size, pad, bos, eos, 26 | n_best=1, cuda=False, 27 | global_scorer=None, 28 | min_length=0, 29 | stepwise_penalty=False, 30 | block_ngram_repeat=0, 31 | exclusion_tokens=set()): 32 | 33 | self.size = size 34 | self.tt = torch.cuda if cuda else torch 35 | 36 | # The score for each translation on the beam. 37 | self.scores = self.tt.FloatTensor(size).zero_() 38 | self.all_scores = [] 39 | 40 | # The backpointers at each time-step. 41 | self.prev_ks = [] 42 | 43 | # The outputs at each time-step. 
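        # next_ys[t] holds the token chosen for each beam at step t, while prev_ks[t]
        # holds the backpointer to the parent beam; get_hyp() walks these two lists
        # backwards to reconstruct a complete hypothesis.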
44 | self.next_ys = [self.tt.LongTensor(size) 45 | .fill_(pad)] 46 | self.next_ys[0][0] = bos 47 | 48 | # Has EOS topped the beam yet. 49 | self._eos = eos 50 | self.eos_top = False 51 | 52 | # The attentions (matrix) for each time. 53 | self.attn = [] 54 | self.hidden = [] 55 | 56 | # Time and k pair for finished. 57 | self.finished = [] 58 | self.n_best = n_best 59 | 60 | # Information for global scoring. 61 | self.global_scorer = global_scorer 62 | self.global_state = {} 63 | 64 | # Minimum prediction length 65 | self.min_length = min_length 66 | 67 | # Apply Penalty at every step 68 | self.stepwise_penalty = stepwise_penalty 69 | self.block_ngram_repeat = block_ngram_repeat 70 | self.exclusion_tokens = exclusion_tokens 71 | 72 | def get_current_state(self): 73 | "Get the outputs for the current timestep." 74 | return self.next_ys[-1] 75 | 76 | def get_current_origin(self): 77 | "Get the backpointers for the current timestep." 78 | return self.prev_ks[-1] 79 | 80 | def advance(self, word_probs, attn_out, hidden): 81 | """ 82 | Given prob over words for every last beam `wordLk` and attention 83 | `attn_out`: Compute and update the beam search. 84 | Parameters: 85 | * `word_probs`- probs of advancing from the last step (K x words) 86 | * `attn_out`- attention at the last step 87 | Returns: True if beam search is complete. 88 | """ 89 | num_words = word_probs.size(1) 90 | if self.stepwise_penalty: 91 | self.global_scorer.update_score(self, attn_out) 92 | # force the output to be longer than self.min_length 93 | cur_len = len(self.next_ys) 94 | if cur_len < self.min_length: 95 | for k in range(len(word_probs)): 96 | word_probs[k][self._eos] = -1e20 97 | # Sum the previous scores. 98 | if len(self.prev_ks) > 0: 99 | beam_scores = word_probs + self.scores.unsqueeze(1) 100 | # Don't let EOS have children. 
101 | for i in range(self.next_ys[-1].size(0)): 102 | if self.next_ys[-1][i] == self._eos: 103 | beam_scores[i] = -1e20 104 | 105 | # Block ngram repeats 106 | if self.block_ngram_repeat > 0: 107 | ngrams = [] 108 | le = len(self.next_ys) 109 | for j in range(self.next_ys[-1].size(0)): 110 | hyp, _ = self.get_hyp(le - 1, j) 111 | ngrams = set() 112 | fail = False 113 | gram = [] 114 | for i in range(le - 1): 115 | # Last n tokens, n = block_ngram_repeat 116 | gram = (gram + 117 | [hyp[i].item()])[-self.block_ngram_repeat:] 118 | # Skip the blocking if it is in the exclusion list 119 | if set(gram) & self.exclusion_tokens: 120 | continue 121 | if tuple(gram) in ngrams: 122 | fail = True 123 | ngrams.add(tuple(gram)) 124 | if fail: 125 | beam_scores[j] = -10e20 126 | else: 127 | beam_scores = word_probs[0] 128 | flat_beam_scores = beam_scores.view(-1) 129 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, 130 | True, True) 131 | 132 | self.all_scores.append(self.scores) 133 | self.scores = best_scores 134 | 135 | # best_scores_id is flattened beam x word array, so calculate which 136 | # word and beam each score came from 137 | prev_k = best_scores_id / num_words 138 | self.prev_ks.append(prev_k) 139 | self.next_ys.append((best_scores_id - prev_k * num_words)) 140 | self.attn.append(attn_out.index_select(0, prev_k)) 141 | self.hidden.append(hidden.index_select(0, prev_k)) 142 | self.global_scorer.update_global_state(self) 143 | 144 | for i in range(self.next_ys[-1].size(0)): 145 | if self.next_ys[-1][i] == self._eos: 146 | global_scores = self.global_scorer.score(self, self.scores) 147 | s = global_scores[i] 148 | self.finished.append((s, len(self.next_ys) - 1, i)) 149 | 150 | # End condition is when top-of-beam is EOS and no global score. 151 | if self.next_ys[-1][0] == self._eos: 152 | self.all_scores.append(self.scores) 153 | self.eos_top = True 154 | 155 | def done(self): 156 | return self.eos_top and len(self.finished) >= self.n_best 157 | 158 | def sort_finished(self, minimum=None): 159 | if minimum is not None: 160 | i = 0 161 | # Add from beam until we have minimum outputs. 162 | while len(self.finished) < minimum: 163 | global_scores = self.global_scorer.score(self, self.scores) 164 | s = global_scores[i] 165 | self.finished.append((s, len(self.next_ys) - 1, i)) 166 | i += 1 167 | 168 | self.finished.sort(key=lambda a: -a[0]) 169 | scores = [sc for sc, _, _ in self.finished] 170 | ks = [(t, k) for _, t, k in self.finished] 171 | return scores, ks 172 | 173 | def get_hyp(self, timestep, k): 174 | """ 175 | Walk back to construct the full hypothesis. 176 | """ 177 | hyp, attn, hidden = [], [], [] 178 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 179 | hyp.append(self.next_ys[j + 1][k]) 180 | attn.append(self.attn[j][k]) 181 | hidden.append(self.hidden[j][k]) 182 | k = self.prev_ks[j][k] 183 | return torch.stack(hyp[::-1]), torch.stack(attn[::-1]), torch.stack(hidden[::-1]) 184 | 185 | 186 | class GNMTGlobalScorer(object): 187 | """ 188 | NMT re-ranking score from 189 | "Google's Neural Machine Translation System" :cite:`wu2016google` 190 | Args: 191 | alpha (float): length parameter 192 | beta (float): coverage parameter 193 | """ 194 | 195 | def __init__(self, opt=None): 196 | self.alpha = 0. 197 | self.beta = 0. 
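        # With alpha and beta left at 0 and the 'none'/'avg' builder below, the
        # coverage penalty is effectively disabled and hypothesis scores are simply
        # length-averaged.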
198 | penalty_builder = PenaltyBuilder('none', 199 | 'avg') 200 | # Term will be subtracted from probability 201 | self.cov_penalty = penalty_builder.coverage_penalty() 202 | # Probability will be divided by this 203 | self.length_penalty = penalty_builder.length_penalty() 204 | 205 | def score(self, beam, logprobs): 206 | """ 207 | Rescores a prediction based on penalty functions 208 | """ 209 | normalized_probs = self.length_penalty(beam, 210 | logprobs, 211 | self.alpha) 212 | if not beam.stepwise_penalty: 213 | penalty = self.cov_penalty(beam, 214 | beam.global_state["coverage"], 215 | self.beta) 216 | normalized_probs -= penalty 217 | 218 | return normalized_probs 219 | 220 | def update_score(self, beam, attn): 221 | """ 222 | Function to update scores of a Beam that is not finished 223 | """ 224 | if "prev_penalty" in list(beam.global_state.keys()): 225 | beam.scores.add_(beam.global_state["prev_penalty"]) 226 | penalty = self.cov_penalty(beam, 227 | beam.global_state["coverage"] + attn, 228 | self.beta) 229 | beam.scores.sub_(penalty) 230 | 231 | def update_global_state(self, beam): 232 | "Keeps the coverage vector as sum of attentions" 233 | if len(beam.prev_ks) == 1: 234 | beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0) 235 | beam.global_state["coverage"] = beam.attn[-1] 236 | self.cov_total = beam.attn[-1].sum(1) 237 | else: 238 | self.cov_total += torch.min(beam.attn[-1], 239 | beam.global_state['coverage']).sum(1) 240 | beam.global_state["coverage"] = beam.global_state["coverage"] \ 241 | .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1]) 242 | 243 | prev_penalty = self.cov_penalty(beam, 244 | beam.global_state["coverage"], 245 | self.beta) 246 | beam.global_state["prev_penalty"] = prev_penalty -------------------------------------------------------------------------------- /pre-train/model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .basic_layers import ConvNorm, LinearNorm, ForwardAttentionV2, Prenet 6 | from .utils import get_mask_from_lengths 7 | 8 | 9 | class Decoder(nn.Module): 10 | def __init__(self, hparams): 11 | super(Decoder, self).__init__() 12 | self.n_mel_channels = hparams.n_mel_channels 13 | self.n_frames_per_step = hparams.n_frames_per_step_decoder 14 | self.hidden_cat_dim = hparams.encoder_embedding_dim + hparams.speaker_embedding_dim 15 | self.attention_rnn_dim = hparams.attention_rnn_dim 16 | self.decoder_rnn_dim = hparams.decoder_rnn_dim 17 | self.prenet_dim = hparams.prenet_dim 18 | self.max_decoder_steps = hparams.max_decoder_steps 19 | self.stop_threshold = hparams.stop_threshold 20 | self.feed_back_last = hparams.feed_back_last 21 | 22 | if hparams.feed_back_last: 23 | prenet_input_dim = hparams.n_mel_channels 24 | else: 25 | prenet_input_dim = hparams.n_mel_channels * hparams.n_frames_per_step_decoder 26 | 27 | self.prenet = Prenet( 28 | prenet_input_dim , 29 | hparams.prenet_dim) 30 | 31 | self.attention_rnn = nn.LSTMCell( 32 | hparams.prenet_dim[-1] + self.hidden_cat_dim, 33 | hparams.attention_rnn_dim) 34 | 35 | self.attention_layer = ForwardAttentionV2( 36 | hparams.attention_rnn_dim, 37 | self.hidden_cat_dim, 38 | hparams.attention_dim, hparams.attention_location_n_filters, 39 | hparams.attention_location_kernel_size) 40 | 41 | self.decoder_rnn = nn.LSTMCell( 42 | self.hidden_cat_dim + hparams.attention_rnn_dim, 43 | hparams.decoder_rnn_dim) 44 | 
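        # The projection below maps [decoder_rnn_output; attention_context] to
        # n_frames_per_step mel frames per decoder step; the stop layer consumes
        # the same concatenated input.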
45 | self.linear_projection = LinearNorm( 46 | self.hidden_cat_dim + hparams.decoder_rnn_dim, 47 | hparams.n_mel_channels * hparams.n_frames_per_step_decoder) 48 | 49 | self.stop_layer = LinearNorm( 50 | self.hidden_cat_dim + hparams.decoder_rnn_dim, 1, 51 | bias=True, w_init_gain='sigmoid') 52 | 53 | def get_go_frame(self, memory): 54 | """ Gets all zeros frames to use as first decoder input 55 | PARAMS 56 | ------ 57 | memory: encoder outputs 58 | RETURNS 59 | ------- 60 | decoder_input: all zeros frames 61 | """ 62 | B = memory.size(0) 63 | if self.feed_back_last: 64 | input_dim = self.n_mel_channels 65 | else: 66 | input_dim = self.n_mel_channels * self.n_frames_per_step 67 | 68 | decoder_input = Variable(memory.data.new( 69 | B, input_dim).zero_()) 70 | return decoder_input 71 | 72 | def initialize_decoder_states(self, memory, mask): 73 | """ Initializes attention rnn states, decoder rnn states, attention 74 | weights, attention cumulative weights, attention context, stores memory 75 | and stores processed memory 76 | PARAMS 77 | ------ 78 | memory: Encoder outputs 79 | mask: Mask for padded data if training, expects None for inference 80 | """ 81 | B = memory.size(0) 82 | MAX_TIME = memory.size(1) 83 | 84 | self.attention_hidden = Variable(memory.data.new( 85 | B, self.attention_rnn_dim).zero_()) 86 | self.attention_cell = Variable(memory.data.new( 87 | B, self.attention_rnn_dim).zero_()) 88 | 89 | self.decoder_hidden = Variable(memory.data.new( 90 | B, self.decoder_rnn_dim).zero_()) 91 | self.decoder_cell = Variable(memory.data.new( 92 | B, self.decoder_rnn_dim).zero_()) 93 | 94 | self.attention_weights = Variable(memory.data.new( 95 | B, MAX_TIME).zero_()) 96 | self.attention_weights_cum = Variable(memory.data.new( 97 | B, MAX_TIME).zero_()) 98 | self.attention_context = Variable(memory.data.new( 99 | B, self.hidden_cat_dim).zero_()) 100 | 101 | self.log_alpha = Variable(memory.data.new(B, MAX_TIME).fill_(-float(1e20))) 102 | self.log_alpha[:, 0].fill_(0.) 103 | 104 | self.memory = memory 105 | self.processed_memory = self.attention_layer.memory_layer(memory) 106 | self.mask = mask 107 | 108 | def parse_decoder_inputs(self, decoder_inputs): 109 | """ Prepares decoder inputs, i.e. mel outputs 110 | PARAMS 111 | ------ 112 | decoder_inputs: inputs used for teacher-forced training, i.e. 
mel-specs 113 | RETURNS 114 | ------- 115 | inputs: processed decoder inputs 116 | """ 117 | # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) 118 | decoder_inputs = decoder_inputs.transpose(1, 2) 119 | decoder_inputs = decoder_inputs.reshape( 120 | decoder_inputs.size(0), 121 | int(decoder_inputs.size(1)/self.n_frames_per_step), -1) 122 | # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) 123 | decoder_inputs = decoder_inputs.transpose(0, 1) 124 | if self.feed_back_last: 125 | decoder_inputs = decoder_inputs[:,:,-self.n_mel_channels:] 126 | 127 | return decoder_inputs 128 | 129 | def parse_decoder_outputs(self, mel_outputs, stop_outputs, alignments): 130 | """ Prepares decoder outputs for output 131 | PARAMS 132 | ------ 133 | mel_outputs: 134 | stop_outputs: stop output energies 135 | alignments: 136 | RETURNS 137 | ------- 138 | mel_outputs: 139 | stop_outputs: stop output energies 140 | alignments: 141 | """ 142 | # (T_out, B, MAX_TIME) -> (B, T_out, MAX_TIME) 143 | alignments = torch.stack(alignments).transpose(0, 1) 144 | # (T_out, B) -> (B, T_out) 145 | if alignments.size(0) == 1: 146 | stop_outputs = torch.stack(stop_outputs).unsqueeze(0) 147 | else: 148 | stop_outputs = torch.stack(stop_outputs).transpose(0, 1) 149 | stop_outputs = stop_outputs.contiguous() 150 | # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) 151 | mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() 152 | # decouple frames per step 153 | mel_outputs = mel_outputs.view( 154 | mel_outputs.size(0), -1, self.n_mel_channels) 155 | # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) 156 | mel_outputs = mel_outputs.transpose(1, 2) 157 | 158 | return mel_outputs, stop_outputs, alignments 159 | 160 | def attend(self, decoder_input): 161 | cell_input = torch.cat((decoder_input, self.attention_context), -1) 162 | self.attention_hidden, self.attention_cell = self.attention_rnn( 163 | cell_input, (self.attention_hidden, self.attention_cell)) 164 | 165 | attention_weights_cat = torch.cat( 166 | (self.attention_weights.unsqueeze(1), 167 | self.attention_weights_cum.unsqueeze(1)), dim=1) 168 | 169 | self.attention_context, self.attention_weights, self.log_alpha = self.attention_layer( 170 | self.attention_hidden, self.memory, self.processed_memory, 171 | attention_weights_cat, self.mask, self.log_alpha) 172 | 173 | self.attention_weights_cum += self.attention_weights 174 | 175 | decoder_rnn_input = torch.cat( 176 | (self.attention_hidden, self.attention_context), -1) 177 | 178 | return decoder_rnn_input, self.attention_context, self.attention_weights 179 | 180 | def decode(self, decoder_input): 181 | 182 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 183 | decoder_input, (self.decoder_hidden, self.decoder_cell)) 184 | 185 | return self.decoder_hidden 186 | 187 | def forward(self, memory, decoder_inputs, memory_lengths): 188 | """ Decoder forward pass for training 189 | PARAMS 190 | ------ 191 | memory: Encoder outputs [B, encoder_max_time, hidden_dim] 192 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs [B, mel_bin, T] 193 | memory_lengths: Encoder output lengths for attention masking. 
[B] 194 | RETURNS 195 | ------- 196 | mel_outputs: mel outputs from the decoder [B, mel_bin, T] 197 | stop_outputs: stop outputs from the decoder [B, T/r] 198 | alignments: sequence of attention weights from the decoder [B, T/r, encoder_max_time] 199 | """ 200 | 201 | decoder_input = self.get_go_frame(memory).unsqueeze(0) 202 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs) 203 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) 204 | decoder_inputs = self.prenet(decoder_inputs) # [T/r + 1, B, prenet_dim ] 205 | 206 | self.initialize_decoder_states( 207 | memory, mask=~get_mask_from_lengths(memory_lengths)) 208 | 209 | mel_outputs, stop_outputs, alignments = [], [], [] 210 | while len(mel_outputs) < decoder_inputs.size(0) - 1: 211 | decoder_input = decoder_inputs[len(mel_outputs)] 212 | 213 | decoder_rnn_input, context, attention_weights = self.attend(decoder_input) 214 | 215 | decoder_rnn_output = self.decode(decoder_rnn_input) 216 | 217 | decoder_hidden_attention_context = torch.cat( 218 | (decoder_rnn_output, context), dim=1) 219 | 220 | mel_output = self.linear_projection(decoder_hidden_attention_context) 221 | stop_output = self.stop_layer(decoder_hidden_attention_context) 222 | 223 | mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze 224 | stop_outputs += [stop_output.squeeze()] 225 | alignments += [attention_weights] 226 | 227 | mel_outputs, stop_outputs, alignments = self.parse_decoder_outputs( 228 | mel_outputs, stop_outputs, alignments) 229 | 230 | return mel_outputs, stop_outputs, alignments 231 | 232 | def inference(self, memory): 233 | """ Decoder inference 234 | PARAMS 235 | ------ 236 | memory: Encoder outputs 237 | RETURNS 238 | ------- 239 | mel_outputs: mel outputs from the decoder 240 | stop_outputs: stop outputs from the decoder 241 | alignments: sequence of attention weights from the decoder 242 | """ 243 | decoder_input = self.get_go_frame(memory) 244 | 245 | self.initialize_decoder_states(memory, mask=None) 246 | 247 | mel_outputs, stop_outputs, alignments = [], [], [] 248 | while True: 249 | decoder_input = self.prenet(decoder_input) 250 | 251 | decoder_input_final, context, alignment = self.attend(decoder_input) 252 | 253 | #mel_output, stop_output, alignment = self.decode(decoder_input) 254 | decoder_rnn_output = self.decode(decoder_input_final) 255 | decoder_hidden_attention_context = torch.cat( 256 | (decoder_rnn_output, context), dim=1) 257 | 258 | mel_output = self.linear_projection(decoder_hidden_attention_context) 259 | stop_output = self.stop_layer(decoder_hidden_attention_context) 260 | 261 | mel_outputs += [mel_output.squeeze(1)] 262 | stop_outputs += [stop_output] 263 | alignments += [alignment] 264 | 265 | 266 | if torch.sigmoid(stop_output.data) > self.stop_threshold: 267 | break 268 | elif len(mel_outputs) == self.max_decoder_steps: 269 | print("Warning! 
Reached max decoder steps") 270 | break 271 | 272 | if self.feed_back_last: 273 | decoder_input = mel_output[:,-self.n_mel_channels:] 274 | else: 275 | decoder_input = mel_output 276 | 277 | mel_outputs, stop_outputs, alignments = self.parse_decoder_outputs( 278 | mel_outputs, stop_outputs, alignments) 279 | 280 | return mel_outputs, stop_outputs, alignments -------------------------------------------------------------------------------- /pre-train/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .utils import get_mask_from_lengths 5 | 6 | class ParrotLoss(nn.Module): 7 | def __init__(self, hparams): 8 | super(ParrotLoss, self).__init__() 9 | self.hidden_dim = hparams.encoder_embedding_dim 10 | self.ce_loss = hparams.ce_loss 11 | 12 | self.L1Loss = nn.L1Loss(reduction='none') 13 | self.MSELoss = nn.MSELoss(reduction='none') 14 | self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss(reduction='none') 15 | self.CrossEntropyLoss = nn.CrossEntropyLoss(reduction='none') 16 | self.n_frames_per_step = hparams.n_frames_per_step_decoder 17 | self.eos = hparams.n_symbols 18 | self.predict_spectrogram = hparams.predict_spectrogram 19 | 20 | self.contr_w = hparams.contrastive_loss_w 21 | self.spenc_w = hparams.speaker_encoder_loss_w 22 | self.texcl_w = hparams.text_classifier_loss_w 23 | self.spadv_w = hparams.speaker_adversial_loss_w 24 | self.spcla_w = hparams.speaker_classifier_loss_w 25 | 26 | def parse_targets(self, targets, text_lengths): 27 | ''' 28 | text_target [batch_size, text_len] 29 | mel_target [batch_size, mel_bins, T] 30 | spc_target [batch_size, spc_bins, T] 31 | speaker_target [batch_size] 32 | stop_target [batch_size, T] 33 | ''' 34 | text_target, mel_target, spc_target, speaker_target, stop_target = targets 35 | 36 | B = stop_target.size(0) 37 | stop_target = stop_target.reshape(B, -1, self.n_frames_per_step) 38 | stop_target = stop_target[:, :, 0] 39 | 40 | padded = torch.tensor(text_target.data.new(B,1).zero_()) 41 | text_target = torch.cat((text_target, padded), dim=-1) 42 | 43 | # adding the ending token for target 44 | for bid in range(B): 45 | text_target[bid, text_lengths[bid].item()] = self.eos 46 | 47 | return text_target, mel_target, spc_target, speaker_target, stop_target 48 | 49 | def forward(self, model_outputs, targets, input_text, eps=1e-5): 50 | 51 | ''' 52 | predicted_mel [batch_size, mel_bins, T] 53 | predicted_stop [batch_size, T/r] 54 | alignment 55 | when input_text==True [batch_size, T/r, max_text_len] 56 | when input_text==False [batch_size, T/r, T/r] 57 | text_hidden [B, max_text_len, hidden_dim] 58 | mel_hidden [B, max_text_len, hidden_dim] 59 | text_logit_from_mel_hidden [B, max_text_len+1, n_symbols+1] 60 | speaker_logit_from_mel [B, n_speakers] 61 | speaker_logit_from_mel_hidden [B, max_text_len, n_speakers] 62 | text_lengths [B,] 63 | mel_lengths [B,] 64 | ''' 65 | predicted_mel, post_output, predicted_stop, alignments,\ 66 | text_hidden, mel_hidden, text_logit_from_mel_hidden, \ 67 | audio_seq2seq_alignments, \ 68 | speaker_logit_from_mel, speaker_logit_from_mel_hidden, \ 69 | text_lengths, mel_lengths = model_outputs 70 | 71 | text_target, mel_target, spc_target, speaker_target, stop_target = self.parse_targets(targets, text_lengths) 72 | 73 | ## get masks ## 74 | mel_mask = get_mask_from_lengths(mel_lengths, mel_target.size(2)).unsqueeze(1).expand(-1, mel_target.size(1), -1).float() 75 | spc_mask = 
get_mask_from_lengths(mel_lengths, mel_target.size(2)).unsqueeze(1).expand(-1, spc_target.size(1), -1).float() 76 | 77 | mel_step_lengths = torch.ceil(mel_lengths.float() / self.n_frames_per_step).long() 78 | stop_mask = get_mask_from_lengths(mel_step_lengths, 79 | int(mel_target.size(2)/self.n_frames_per_step)).float() # [B, T/r] 80 | text_mask = get_mask_from_lengths(text_lengths).float() 81 | text_mask_plus_one = get_mask_from_lengths(text_lengths + 1).float() 82 | 83 | # reconstruction loss # 84 | recon_loss = torch.sum(self.L1Loss(predicted_mel, mel_target) * mel_mask) / torch.sum(mel_mask) 85 | 86 | if self.predict_spectrogram: 87 | recon_loss_post = (self.L1Loss(post_output, spc_target) * spc_mask).sum() / spc_mask.sum() 88 | else: 89 | recon_loss_post = (self.L1Loss(post_output, mel_target) * mel_mask).sum() / torch.sum(mel_mask) 90 | 91 | stop_loss = torch.sum(self.BCEWithLogitsLoss(predicted_stop, stop_target) * stop_mask) / torch.sum(stop_mask) 92 | 93 | 94 | if self.contr_w == 0.: 95 | contrast_loss = torch.tensor(0.).cuda() 96 | else: 97 | # contrastive mask # 98 | contrast_mask1 = get_mask_from_lengths(text_lengths).unsqueeze(2).expand(-1, -1, mel_hidden.size(1)) # [B, text_len] -> [B, text_len, T/r] 99 | contrast_mask2 = get_mask_from_lengths(text_lengths).unsqueeze(1).expand(-1, text_hidden.size(1), -1) # [B, T/r] -> [B, text_len, T/r] 100 | contrast_mask = (contrast_mask1 & contrast_mask2).float() 101 | text_hidden_normed = text_hidden / (torch.norm(text_hidden, dim=2, keepdim=True) + eps) 102 | mel_hidden_normed = mel_hidden / (torch.norm(mel_hidden, dim=2, keepdim=True) + eps) 103 | 104 | # (x - y) ** 2 = x ** 2 + y ** 2 - 2xy 105 | distance_matrix_xx = torch.sum(text_hidden_normed ** 2, dim=2, keepdim=True) #[batch_size, text_len, 1] 106 | distance_matrix_yy = torch.sum(mel_hidden_normed ** 2, dim=2) 107 | distance_matrix_yy = distance_matrix_yy.unsqueeze(1) #[batch_size, 1, text_len] 108 | 109 | #[batch_size, text_len, text_len] 110 | distance_matrix_xy = torch.bmm(text_hidden_normed, torch.transpose(mel_hidden_normed, 1, 2)) 111 | distance_matrix = distance_matrix_xx + distance_matrix_yy - 2 * distance_matrix_xy 112 | 113 | TTEXT = distance_matrix.size(1) 114 | hard_alignments = torch.eye(TTEXT).cuda() 115 | contrast_loss = hard_alignments * distance_matrix + \ 116 | (1. - hard_alignments) * torch.max(1. 
- distance_matrix, torch.zeros_like(distance_matrix)) 117 | 118 | contrast_loss = torch.sum(contrast_loss * contrast_mask) / torch.sum(contrast_mask) 119 | 120 | n_speakers = speaker_logit_from_mel_hidden.size(2) 121 | TTEXT = speaker_logit_from_mel_hidden.size(1) 122 | n_symbols_plus_one = text_logit_from_mel_hidden.size(2) 123 | 124 | # speaker classification loss # 125 | speaker_encoder_loss = nn.CrossEntropyLoss()(speaker_logit_from_mel, speaker_target) 126 | _, predicted_speaker = torch.max(speaker_logit_from_mel,dim=1) 127 | speaker_encoder_acc = ((predicted_speaker == speaker_target).float()).sum() / float(speaker_target.size(0)) 128 | 129 | speaker_logit_flatten = speaker_logit_from_mel_hidden.reshape(-1, n_speakers) # -> [B* TTEXT, n_speakers] 130 | _, predicted_speaker = torch.max(speaker_logit_flatten, dim=1) 131 | speaker_target_flatten = speaker_target.unsqueeze(1).expand(-1, TTEXT).reshape(-1) 132 | speaker_classification_acc = ((predicted_speaker == speaker_target_flatten).float() * text_mask.reshape(-1)).sum() / text_mask.sum() 133 | loss = self.CrossEntropyLoss(speaker_logit_flatten, speaker_target_flatten) 134 | 135 | speaker_classification_loss = torch.sum(loss * text_mask.reshape(-1)) / torch.sum(text_mask) 136 | 137 | # text classification loss # 138 | text_logit_flatten = text_logit_from_mel_hidden.reshape(-1, n_symbols_plus_one) 139 | text_target_flatten = text_target.reshape(-1) 140 | _, predicted_text = torch.max(text_logit_flatten, dim=1) 141 | text_classification_acc = ((predicted_text == text_target_flatten).float()*text_mask_plus_one.reshape(-1)).sum()/text_mask_plus_one.sum() 142 | loss = self.CrossEntropyLoss(text_logit_flatten, text_target_flatten) 143 | text_classification_loss = torch.sum(loss * text_mask_plus_one.reshape(-1)) / torch.sum(text_mask_plus_one) 144 | 145 | # speaker adversarial loss # 146 | flatten_target = 1. 
/ n_speakers * torch.ones_like(speaker_logit_flatten) 147 | loss = self.MSELoss(F.softmax(speaker_logit_flatten, dim=1), flatten_target) 148 | mask = text_mask.unsqueeze(2).expand(-1,-1, n_speakers).reshape(-1, n_speakers) 149 | 150 | if self.ce_loss: 151 | speaker_adversial_loss = - speaker_classification_loss 152 | else: 153 | speaker_adversial_loss = torch.sum(loss * mask) / torch.sum(mask) 154 | 155 | loss_list = [recon_loss, recon_loss_post, stop_loss, 156 | contrast_loss, speaker_encoder_loss, speaker_classification_loss, 157 | text_classification_loss, speaker_adversial_loss] 158 | 159 | acc_list = [speaker_encoder_acc, speaker_classification_acc, text_classification_acc] 160 | 161 | 162 | combined_loss1 = recon_loss + recon_loss_post + stop_loss + self.contr_w * contrast_loss + \ 163 | self.spenc_w * speaker_encoder_loss + self.texcl_w * text_classification_loss + \ 164 | self.spadv_w * speaker_adversial_loss 165 | 166 | combined_loss2 = self.spcla_w * speaker_classification_loss 167 | 168 | return loss_list, acc_list, combined_loss1, combined_loss2 169 | 170 | -------------------------------------------------------------------------------- /pre-train/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from math import sqrt 5 | from .utils import to_gpu 6 | from .decoder import Decoder 7 | from .layers import SpeakerClassifier, SpeakerEncoder, AudioSeq2seq, TextEncoder, PostNet, MergeNet 8 | 9 | 10 | class Parrot(nn.Module): 11 | def __init__(self, hparams): 12 | super(Parrot, self).__init__() 13 | 14 | #print hparams 15 | # plus 16 | self.embedding = nn.Embedding( 17 | hparams.n_symbols + 1, hparams.symbols_embedding_dim) 18 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) 19 | val = sqrt(3.0) * std 20 | 21 | self.sos = hparams.n_symbols 22 | 23 | self.embedding.weight.data.uniform_(-val, val) 24 | 25 | self.text_encoder = TextEncoder(hparams) 26 | 27 | self.audio_seq2seq = AudioSeq2seq(hparams) 28 | 29 | self.merge_net = MergeNet(hparams) 30 | 31 | self.speaker_encoder = SpeakerEncoder(hparams) 32 | 33 | self.speaker_classifier = SpeakerClassifier(hparams) 34 | 35 | self.decoder = Decoder(hparams) 36 | 37 | self.postnet = PostNet(hparams) 38 | 39 | self.spemb_input = hparams.spemb_input 40 | 41 | def grouped_parameters(self,): 42 | 43 | params_group1 = [p for p in self.embedding.parameters()] 44 | params_group1.extend([p for p in self.text_encoder.parameters()]) 45 | params_group1.extend([p for p in self.audio_seq2seq.parameters()]) 46 | params_group1.extend([p for p in self.speaker_encoder.parameters()]) 47 | params_group1.extend([p for p in self.merge_net.parameters()]) 48 | params_group1.extend([p for p in self.decoder.parameters()]) 49 | params_group1.extend([p for p in self.postnet.parameters()]) 50 | 51 | return params_group1, [p for p in self.speaker_classifier.parameters()] 52 | 53 | def parse_batch(self, batch): 54 | text_input_padded, mel_padded, spc_padded, speaker_id, \ 55 | text_lengths, mel_lengths, stop_token_padded = batch 56 | 57 | text_input_padded = to_gpu(text_input_padded).long() 58 | mel_padded = to_gpu(mel_padded).float() 59 | spc_padded = to_gpu(spc_padded).float() 60 | speaker_id = to_gpu(speaker_id).long() 61 | text_lengths = to_gpu(text_lengths).long() 62 | mel_lengths = to_gpu(mel_lengths).long() 63 | stop_token_padded = to_gpu(stop_token_padded).float() 64 | 65 | return ((text_input_padded, 
mel_padded, text_lengths, mel_lengths), 66 | (text_input_padded, mel_padded, spc_padded, speaker_id, stop_token_padded)) 67 | 68 | 69 | def forward(self, inputs, input_text): 70 | ''' 71 | text_input_padded [batch_size, max_text_len] 72 | mel_padded [batch_size, mel_bins, max_mel_len] 73 | text_lengths [batch_size] 74 | mel_lengths [batch_size] 75 | 76 | # 77 | predicted_mel [batch_size, mel_bins, T] 78 | predicted_stop [batch_size, T/r] 79 | alignment input_text==True [batch_size, T/r, max_text_len] or input_text==False [batch_size, T/r, T/r] 80 | text_hidden [B, max_text_len, hidden_dim] 81 | mel_hidden [B, T/r, hidden_dim] 82 | speaker_logit_from_mel [B, n_speakers] 83 | speaker_logit_from_mel_hidden [B, T/r, n_speakers] 84 | text_logit_from_mel_hidden [B, T/r, n_symbols] 85 | 86 | ''' 87 | 88 | text_input_padded, mel_padded, text_lengths, mel_lengths = inputs 89 | 90 | text_input_embedded = self.embedding(text_input_padded.long()).transpose(1, 2) # -> [B, text_embedding_dim, max_text_len] 91 | text_hidden = self.text_encoder(text_input_embedded, text_lengths) # -> [B, max_text_len, hidden_dim] 92 | 93 | B = text_input_padded.size(0) 94 | start_embedding = Variable(text_input_padded.data.new(B,).fill_(self.sos)) 95 | start_embedding = self.embedding(start_embedding) 96 | 97 | # -> [B, n_speakers], [B, speaker_embedding_dim] 98 | speaker_logit_from_mel, speaker_embedding = self.speaker_encoder(mel_padded, mel_lengths) 99 | 100 | if self.spemb_input: 101 | T = mel_padded.size(2) 102 | audio_input = torch.cat([mel_padded, 103 | speaker_embedding.detach().unsqueeze(2).expand(-1, -1, T)], 1) 104 | else: 105 | audio_input = mel_padded 106 | 107 | audio_seq2seq_hidden, audio_seq2seq_logit, audio_seq2seq_alignments = self.audio_seq2seq( 108 | audio_input, mel_lengths, text_input_embedded, start_embedding) 109 | audio_seq2seq_hidden = audio_seq2seq_hidden[:,:-1, :] # -> [B, text_len, hidden_dim] 110 | 111 | 112 | speaker_logit_from_mel_hidden = self.speaker_classifier(audio_seq2seq_hidden) # -> [B, text_len, n_speakers] 113 | 114 | if input_text: 115 | hidden = self.merge_net(text_hidden, text_lengths) 116 | else: 117 | hidden = self.merge_net(audio_seq2seq_hidden, text_lengths) 118 | 119 | L = hidden.size(1) 120 | hidden = torch.cat([hidden, speaker_embedding.detach().unsqueeze(1).expand(-1, L, -1)], -1) 121 | 122 | predicted_mel, predicted_stop, alignments = self.decoder(hidden, mel_padded, text_lengths) 123 | 124 | post_output = self.postnet(predicted_mel) 125 | 126 | outputs = [predicted_mel, post_output, predicted_stop, alignments, 127 | text_hidden, audio_seq2seq_hidden, audio_seq2seq_logit, audio_seq2seq_alignments, 128 | speaker_logit_from_mel, speaker_logit_from_mel_hidden, 129 | text_lengths, mel_lengths] 130 | 131 | return outputs 132 | 133 | 134 | def inference(self, inputs, input_text, mel_reference, beam_width): 135 | ''' 136 | decode the audio sequence from input 137 | inputs x 138 | input_text True or False 139 | mel_reference [1, mel_bins, T] 140 | ''' 141 | text_input_padded, mel_padded, text_lengths, mel_lengths = inputs 142 | text_input_embedded = self.embedding(text_input_padded.long()).transpose(1, 2) 143 | text_hidden = self.text_encoder.inference(text_input_embedded) 144 | 145 | B = text_input_padded.size(0) # B should be 1 146 | start_embedding = Variable(text_input_padded.data.new(B,).fill_(self.sos)) 147 | start_embedding = self.embedding(start_embedding) # [1, embedding_dim] 148 | 149 | #-> [B, text_len+1, hidden_dim] [B, text_len+1, n_symbols] [B, text_len+1, T/r] 
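        # The speaker encoder turns the reference mel into a speaker id and a speaker
        # embedding; the embedding conditions the decoder below in both TTS and VC mode.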
150 | speaker_id, speaker_embedding = self.speaker_encoder.inference(mel_reference) 151 | 152 | if self.spemb_input: 153 | T = mel_padded.size(2) 154 | audio_input = torch.cat([mel_padded, 155 | speaker_embedding.detach().unsqueeze(2).expand(-1, -1, T)], 1) 156 | else: 157 | audio_input = mel_padded 158 | 159 | audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments = self.audio_seq2seq.inference_beam( 160 | audio_input, start_embedding, self.embedding, beam_width=beam_width) 161 | audio_seq2seq_hidden= audio_seq2seq_hidden[:,:-1, :] # -> [B, text_len, hidden_dim] 162 | 163 | # -> [B, n_speakers], [B, speaker_embedding_dim] 164 | 165 | if input_text: 166 | hidden = self.merge_net.inference(text_hidden) 167 | else: 168 | hidden = self.merge_net.inference(audio_seq2seq_hidden) 169 | 170 | L = hidden.size(1) 171 | hidden = torch.cat([hidden, speaker_embedding.detach().unsqueeze(1).expand(-1, L, -1)], -1) 172 | 173 | predicted_mel, predicted_stop, alignments = self.decoder.inference(hidden) 174 | 175 | post_output = self.postnet(predicted_mel) 176 | 177 | return (predicted_mel, post_output, predicted_stop, alignments, 178 | text_hidden, audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments, 179 | speaker_id) 180 | 181 | 182 | -------------------------------------------------------------------------------- /pre-train/model/penalties.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """ 7 | Returns the Length and Coverage Penalty function for Beam Search. 8 | Args: 9 | length_pen (str): option name of length pen 10 | cov_pen (str): option name of cov pen 11 | """ 12 | 13 | def __init__(self, cov_pen, length_pen): 14 | self.length_pen = length_pen 15 | self.cov_pen = cov_pen 16 | 17 | def coverage_penalty(self): 18 | if self.cov_pen == "wu": 19 | return self.coverage_wu 20 | elif self.cov_pen == "summary": 21 | return self.coverage_summary 22 | else: 23 | return self.coverage_none 24 | 25 | def length_penalty(self): 26 | if self.length_pen == "wu": 27 | return self.length_wu 28 | elif self.length_pen == "avg": 29 | return self.length_average 30 | else: 31 | return self.length_none 32 | 33 | """ 34 | Below are all the different penalty terms implemented so far 35 | """ 36 | 37 | def coverage_wu(self, beam, cov, beta=0.): 38 | """ 39 | NMT coverage re-ranking score from 40 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 41 | """ 42 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(1) 43 | return beta * penalty 44 | 45 | def coverage_summary(self, beam, cov, beta=0.): 46 | """ 47 | Our summary penalty. 48 | """ 49 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(1) 50 | penalty -= cov.size(1) 51 | return beta * penalty 52 | 53 | def coverage_none(self, beam, cov, beta=0.): 54 | """ 55 | returns zero as penalty 56 | """ 57 | return beam.scores.clone().fill_(0.0) 58 | 59 | def length_wu(self, beam, logprobs, alpha=0.): 60 | """ 61 | NMT length re-ranking score from 62 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 63 | """ 64 | 65 | modifier = (((5 + len(beam.next_ys)) ** alpha) / 66 | ((5 + 1) ** alpha)) 67 | return (logprobs / modifier) 68 | 69 | def length_average(self, beam, logprobs, alpha=0.): 70 | """ 71 | Returns the average probability of tokens in a sequence. 
72 | """ 73 | return logprobs / len(beam.next_ys) 74 | 75 | def length_none(self, beam, logprobs, alpha=0., beta=0.): 76 | """ 77 | Returns unmodified scores. 78 | """ 79 | return logprobs -------------------------------------------------------------------------------- /pre-train/model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def gcd(a,b): 6 | a, b = (a, b) if a >=b else (b, a) 7 | if a%b == 0: 8 | return b 9 | else : 10 | return gcd(b,a%b) 11 | 12 | def lcm(a,b): 13 | return a*b//gcd(a,b) 14 | 15 | 16 | if __name__ == "__main__": 17 | print(lcm(3,2)) 18 | 19 | def get_mask_from_lengths(lengths, max_len=None): 20 | if max_len is None: 21 | max_len = torch.max(lengths).item() 22 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 23 | #print ids 24 | mask = (ids < lengths.unsqueeze(1)).byte() 25 | return mask 26 | 27 | def to_gpu(x): 28 | x = x.contiguous() 29 | 30 | if torch.cuda.is_available(): 31 | x = x.cuda(non_blocking=True) 32 | return torch.autograd.Variable(x) 33 | 34 | def test_mask(): 35 | lengths = torch.IntTensor([3,5,4]) 36 | print(torch.ceil(lengths.float() / 2)) 37 | 38 | data = torch.FloatTensor(3, 5, 2) # [B, T, D] 39 | data.fill_(1.) 40 | m = get_mask_from_lengths(lengths.cuda(), data.size(1)) 41 | print(m) 42 | m = m.unsqueeze(2).expand(-1,-1,data.size(2)).float() 43 | print(m) 44 | 45 | print(torch.sum(data.cuda() * m) / torch.sum(m)) 46 | 47 | 48 | def test_loss(): 49 | data1 = torch.FloatTensor(3, 5, 2) 50 | data1.fill_(1.) 51 | data2 = torch.FloatTensor(3, 5, 2) 52 | data2.fill_(2.) 53 | data2[0,0,0] = 1000 54 | 55 | l = torch.nn.L1Loss(reduction='none')(data1,data2) 56 | print(l) 57 | 58 | 59 | #if __name__ == '__main__': 60 | # test_mask() -------------------------------------------------------------------------------- /pre-train/multiproc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sys 4 | import subprocess 5 | 6 | argslist = list(sys.argv)[1:] 7 | num_gpus = torch.cuda.device_count() 8 | argslist.append('--n_gpus={}'.format(num_gpus)) 9 | workers = [] 10 | job_id = time.strftime("%Y_%m_%d-%H%M%S") 11 | argslist.append("--group_name=group_{}".format(job_id)) 12 | 13 | for i in range(num_gpus): 14 | argslist.append('--rank={}'.format(i)) 15 | stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), 16 | "w") 17 | print(argslist) 18 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 19 | workers.append(p) 20 | argslist = argslist[:-1] 21 | 22 | for p in workers: 23 | p.wait() 24 | -------------------------------------------------------------------------------- /pre-train/plotting_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 
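    # Parse the rendered RGB canvas into an (H, W, 3) uint8 array. Note that
    # np.fromstring is deprecated on newer NumPy; np.frombuffer(fig.canvas.tostring_rgb(),
    # dtype=np.uint8) should be a drop-in replacement here.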
9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data 12 | 13 | def plot_alignment(alignment, fn): 14 | # [4, encoder_step, decoder_step] 15 | fig, axes = plt.subplots(2, 2) 16 | for i in range(2): 17 | for j in range(2): 18 | g = axes[i][j].imshow(alignment[i*2+j,:,:].T, 19 | aspect='auto', origin='lower', 20 | interpolation='none') 21 | plt.colorbar(g, ax=axes[i][j]) 22 | 23 | plt.savefig(fn) 24 | plt.close() 25 | return fn 26 | 27 | 28 | def plot_alignment_to_numpy(alignment, info=None): 29 | fig, ax = plt.subplots(figsize=(6, 4)) 30 | im = ax.imshow(alignment, aspect='auto', origin='lower', 31 | interpolation='none') 32 | fig.colorbar(im, ax=ax) 33 | xlabel = 'Decoder timestep' 34 | if info is not None: 35 | xlabel += '\n\n' + info 36 | plt.xlabel(xlabel) 37 | plt.ylabel('Encoder timestep') 38 | plt.tight_layout() 39 | 40 | fig.canvas.draw() 41 | data = save_figure_to_numpy(fig) 42 | plt.close() 43 | return data 44 | 45 | 46 | def plot_spectrogram_to_numpy(spectrogram): 47 | fig, ax = plt.subplots(figsize=(12, 3)) 48 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 49 | interpolation='none') 50 | plt.colorbar(im, ax=ax) 51 | plt.xlabel("Frames") 52 | plt.ylabel("Channels") 53 | plt.tight_layout() 54 | 55 | fig.canvas.draw() 56 | data = save_figure_to_numpy(fig) 57 | plt.close() 58 | return data 59 | 60 | 61 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 62 | fig, ax = plt.subplots(figsize=(12, 3)) 63 | ax.scatter(list(range(len(gate_targets))), gate_targets, alpha=0.5, 64 | color='green', marker='+', s=1, label='target') 65 | ax.scatter(list(range(len(gate_outputs))), gate_outputs, alpha=0.5, 66 | color='red', marker='.', s=1, label='predicted') 67 | 68 | plt.xlabel("Frames (Green target, Red predicted)") 69 | plt.ylabel("Gate State") 70 | plt.tight_layout() 71 | 72 | fig.canvas.draw() 73 | data = save_figure_to_numpy(fig) 74 | plt.close() 75 | return data 76 | -------------------------------------------------------------------------------- /pre-train/reader/__init__.py: -------------------------------------------------------------------------------- 1 | from .reader import TextMelIDLoader, TextMelIDCollate 2 | from .symbols import id2sp, id2ph -------------------------------------------------------------------------------- /pre-train/reader/extract_features.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import glob 4 | import os 5 | from multiprocessing import Pool, cpu_count 6 | import sys 7 | 8 | def extract_mel_spec(filename): 9 | ''' 10 | extract and save both log-linear and log-Mel spectrograms. 
11 | saved spec shape [n_frames, 1025] 12 | saved mel shape [n_frames, 80] 13 | ''' 14 | y, sample_rate = librosa.load(filename) 15 | 16 | spec = librosa.core.stft(y=y, 17 | n_fft=2048, 18 | hop_length=200, 19 | win_length=800, 20 | window='hann', 21 | center=True, 22 | pad_mode='reflect') 23 | spec= librosa.magphase(spec)[0] 24 | log_spectrogram = np.log(spec).astype(np.float32) 25 | 26 | mel_spectrogram = librosa.feature.melspectrogram(S=spec, 27 | sr=sample_rate, 28 | n_mels=80, 29 | power=1.0, #actually not used given "S=spec" 30 | fmin=0.0, 31 | fmax=None, 32 | htk=False, 33 | norm=1 34 | ) 35 | log_mel_spectrogram = np.log(mel_spectrogram).astype(np.float32) 36 | 37 | np.save(file=filename.replace(".wav", ".spec"), arr=log_spectrogram.T) 38 | np.save(file=filename.replace(".wav", ".mel"), arr=log_mel_spectrogram.T) 39 | 40 | 41 | def extract_phonemes(filename): 42 | from phonemizer.phonemize import phonemize 43 | from phonemizer.backend import FestivalBackend 44 | from phonemizer.separator import Separator 45 | 46 | with open(filename) as f: 47 | text=f.read() 48 | phones = phonemize(text, 49 | language='en-us', 50 | backend='festival', 51 | separator=Separator(phone=' ', 52 | syllable='', 53 | word='') 54 | ) 55 | 56 | with open(filename.replace(".txt", ".phones"), "w") as outfile: 57 | print(phones, file=outfile) 58 | 59 | def extract_dir(root, kind): 60 | if kind =="audio": 61 | extraction_function=extract_mel_spec 62 | ext=".wav" 63 | elif kind =="text": 64 | extraction_function=extract_phonemes 65 | ext=".txt" 66 | else: 67 | print("ERROR: invalid args") 68 | sys.exit(1) 69 | if not os.path.isdir(root): 70 | print("ERROR: invalid args") 71 | sys.exit(1) 72 | 73 | # traverse over all subdirs of the provided dir, and find 74 | # only files with the proper extension 75 | abs_paths=[] 76 | for dirpath, _, filenames in os.walk(root): 77 | for f in filenames: 78 | abs_path = os.path.abspath(os.path.join(dirpath, f)) 79 | if abs_path.endswith(ext): 80 | abs_paths.append(abs_path) 81 | 82 | pool = Pool(cpu_count()) 83 | pool.map(extraction_function,abs_paths) 84 | 85 | #estimate and save mean std statistics in root dir. 86 | estimate_mean_std(root) 87 | 88 | 89 | def estimate_mean_std(root, num=2000): 90 | ''' 91 | use the training data for estimating mean and standard deviation 92 | use $num utterances to avoid out of memory 93 | ''' 94 | specs, mels = [], [] 95 | counter_sp, counter_mel = 0, 0 96 | for dirpath, _, filenames in os.walk(root): 97 | for f in filenames: 98 | if f.endswith('.spec.npy') and counter_sp= 1000: 34 | continue 35 | file_path_list.append(path) 36 | 37 | if shuffle: 38 | random.seed(1234) 39 | random.shuffle(file_path_list) 40 | 41 | self.file_path_list = file_path_list 42 | self.mel_mean_std = np.float32(np.load(mean_std_file)) 43 | self.spc_mean_std = np.float32(np.load(mean_std_file.replace('mel', 'spec'))) 44 | 45 | def get_path_id(self, path): 46 | # Custom this function to obtain paths and speaker id 47 | # Deduce filenames 48 | spec_path = path 49 | text_path = path.replace('spec', 'text').replace('npy', 'txt').replace('log-', '') 50 | mel_path = path.replace('spec', 'mel') 51 | speaker_id = path.split('/')[-2] 52 | 53 | return mel_path, spec_path, text_path, speaker_id 54 | 55 | 56 | def get_text_mel_id_pair(self, path): 57 | ''' 58 | You should Modify this function to read your own data. 
59 | 60 | Returns: 61 | 62 | object: dimensionality 63 | ----------------------- 64 | text_input: [len_text] 65 | mel: [mel_bin, len_mel] 66 | spc: [spc_bin, len_spc] 67 | speaker_id: [1] 68 | ''' 69 | 70 | mel_path, spec_path, text_path, speaker_id = self.get_path_id(path) 71 | # Load data from disk 72 | text_input = self.get_text(text_path) 73 | mel = np.load(mel_path) 74 | spc = np.load(spec_path) 75 | # Normalize audio 76 | mel = (mel - self.mel_mean_std[0])/ self.mel_mean_std[1] 77 | spc = (spc - self.spc_mean_std[0]) / self.spc_mean_std[1] 78 | # Format for pytorch 79 | text_input = torch.LongTensor(text_input) 80 | mel = torch.from_numpy(np.transpose(mel)) 81 | spc = torch.from_numpy(np.transpose(spc)) 82 | speaker_id = torch.LongTensor([sp2id[speaker_id]]) 83 | 84 | return (text_input, mel, spc, speaker_id) 85 | 86 | def get_text(self,text_path): 87 | ''' 88 | Returns: 89 | 90 | text_input: a list of phoneme IDs corresponding 91 | to the transcript of one utterance 92 | ''' 93 | text = read_text(text_path) 94 | text_input = [] 95 | 96 | for start, end, ph in text: 97 | text_input.append(ph2id[ph]) 98 | 99 | return text_input 100 | 101 | def __getitem__(self, index): 102 | return self.get_text_mel_id_pair(self.file_path_list[index]) 103 | 104 | def __len__(self): 105 | return len(self.file_path_list) 106 | 107 | 108 | class TextMelIDCollate(): 109 | 110 | def __init__(self, n_frames_per_step=2): 111 | self.n_frames_per_step = n_frames_per_step 112 | 113 | def __call__(self, batch): 114 | ''' 115 | batch is list of (text_input, mel, spc, speaker_id) 116 | ''' 117 | text_lengths = torch.IntTensor([len(x[0]) for x in batch]) 118 | mel_lengths = torch.IntTensor([x[1].size(1) for x in batch]) 119 | mel_bin = batch[0][1].size(0) 120 | spc_bin = batch[0][2].size(0) 121 | 122 | max_text_len = torch.max(text_lengths).item() 123 | max_mel_len = torch.max(mel_lengths).item() 124 | if max_mel_len % self.n_frames_per_step != 0: 125 | max_mel_len += self.n_frames_per_step - max_mel_len % self.n_frames_per_step 126 | assert max_mel_len % self.n_frames_per_step == 0 127 | 128 | text_input_padded = torch.LongTensor(len(batch), max_text_len) 129 | mel_padded = torch.FloatTensor(len(batch), mel_bin, max_mel_len) 130 | spc_padded = torch.FloatTensor(len(batch), spc_bin, max_mel_len) 131 | 132 | speaker_id = torch.LongTensor(len(batch)) 133 | stop_token_padded = torch.FloatTensor(len(batch), max_mel_len) 134 | 135 | text_input_padded.zero_() 136 | mel_padded.zero_() 137 | spc_padded.zero_() 138 | speaker_id.zero_() 139 | stop_token_padded.zero_() 140 | 141 | for i in range(len(batch)): 142 | text = batch[i][0] 143 | mel = batch[i][1] 144 | spc = batch[i][2] 145 | 146 | text_input_padded[i,:text.size(0)] = text 147 | mel_padded[i, :, :mel.size(1)] = mel 148 | spc_padded[i, :, :spc.size(1)] = spc 149 | speaker_id[i] = batch[i][3][0] 150 | # make sure the downsampled stop_token_padded has the final end flag set to 1. 
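            # Frames from the start of the utterance's last decoder step through the padded
            # end are flagged, so after the [B, -1, n_frames_per_step] reshape in loss.py the
            # reduced stop target still carries a 1 at the true end of the utterance.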
151 | stop_token_padded[i, mel.size(1)-self.n_frames_per_step:] = 1 152 | 153 | return text_input_padded, mel_padded, spc_padded, speaker_id, \ 154 | text_lengths, mel_lengths, stop_token_padded 155 | 156 | -------------------------------------------------------------------------------- /pre-train/reader/symbols.py: -------------------------------------------------------------------------------- 1 | phone_list = ['pau', 'iy', 'aa', 'ch', 'ae', 'eh', 2 | 'ah', 'ao', 'ih', 'ey', 'aw', 3 | 'ay', 'ax', 'er', 'ng', 4 | 'sh', 'th', 'uh', 'zh', 'oy', 5 | 'dh', 'y', 'hh', 'jh', 'b', 6 | 'd', 'g', 'f', 'k', 'm', 7 | 'l', 'n', 'p', 's', 'r', 8 | 't', 'w', 'v', 'ow', 'z', 9 | 'uw', 'SOS/EOS'] 10 | 11 | seen_speakers = ['p336', 'p240', 'p262', 'p333', 'p297', 'p339', 'p276', 'p269', 'p303', 'p260', 'p250', 'p345', 'p305', 'p283', 'p277', 'p302', 'p280', 'p295', 'p245', 'p227', 'p257', 'p282', 'p259', 'p311', 'p301', 'p265', 'p270', 'p329', 'p362', 'p343', 'p246', 'p247', 'p351', 'p263', 'p363', 'p249', 'p231', 'p292', 'p304', 'p347', 'p314', 'p244', 'p261', 'p298', 'p272', 'p308', 'p299', 'p234', 'p268', 'p271', 'p316', 'p287', 'p318', 'p264', 'p313', 'p236', 'p238', 'p334', 'p312', 'p230', 'p253', 'p323', 'p361', 'p275', 'p252', 'p374', 'p286', 'p274', 'p254', 'p310', 'p306', 'p294', 'p326', 'p225', 'p255', 'p293', 'p278', 'p266', 'p229', 'p335', 'p281', 'p307', 'p256', 'p243', 'p364', 'p239', 'p232', 'p258', 'p267', 'p317', 'p284', 'p300', 'p288', 'p341', 'p340', 'p279', 'p330', 'p360', 'p285'] 12 | 13 | ph2id = {ph:i for i, ph in enumerate(phone_list)} 14 | id2ph = {i:ph for i, ph in enumerate(phone_list)} 15 | sp2id = {sp:i for i, sp in enumerate(seen_speakers)} 16 | id2sp = {i:sp for i, sp in enumerate(seen_speakers)} 17 | -------------------------------------------------------------------------------- /pre-train/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # you can set the hparams by using --hparams=xxx 4 | CUDA_VISIBLE_DEVICES=3 python train.py -l logdir \ 5 | -o outdir --n_gpus=1 --hparams=speaker_adversial_loss_w=20.,ce_loss=False,speaker_classifier_loss_w=0.1,contrastive_loss_w=30. -------------------------------------------------------------------------------- /pre-train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import math 5 | from numpy import finfo 6 | import numpy as np 7 | 8 | import torch 9 | from distributed import apply_gradient_allreduce 10 | import torch.distributed as dist 11 | from torch.utils.data.distributed import DistributedSampler 12 | from torch.utils.data import DataLoader 13 | 14 | from model import Parrot, ParrotLoss, lcm 15 | from reader import TextMelIDLoader, TextMelIDCollate 16 | from logger import ParrotLogger 17 | from hparams import create_hparams 18 | 19 | 20 | def batchnorm_to_float(module): 21 | """Converts batch norm modules to FP32""" 22 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 23 | module.float() 24 | for child in module.children(): 25 | batchnorm_to_float(child) 26 | return module 27 | 28 | 29 | def reduce_tensor(tensor, n_gpus): 30 | rt = tensor.clone() 31 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 32 | rt /= n_gpus 33 | return rt 34 | 35 | 36 | def init_distributed(hparams, n_gpus, rank, group_name): 37 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 
38 | print("Initializing Distributed") 39 | 40 | # Set cuda device so everything is done on the right GPU. 41 | torch.cuda.set_device(rank % torch.cuda.device_count()) 42 | 43 | # Initialize distributed communication 44 | dist.init_process_group( 45 | backend=hparams.dist_backend, init_method=hparams.dist_url, 46 | world_size=n_gpus, rank=rank, group_name=group_name) 47 | 48 | print("Done initializing distributed") 49 | 50 | 51 | def prepare_dataloaders(hparams): 52 | # Get data, data loaders and collate function ready 53 | trainset = TextMelIDLoader(hparams.training_list, hparams.mel_mean_std) 54 | valset = TextMelIDLoader(hparams.validation_list, hparams.mel_mean_std) 55 | collate_fn = TextMelIDCollate(lcm(hparams.n_frames_per_step_encoder, 56 | hparams.n_frames_per_step_decoder)) 57 | 58 | train_sampler = DistributedSampler(trainset) \ 59 | if hparams.distributed_run else None 60 | 61 | train_loader = DataLoader(trainset, num_workers=1, shuffle=True, 62 | sampler=train_sampler, 63 | batch_size=hparams.batch_size, pin_memory=False, 64 | drop_last=True, collate_fn=collate_fn) 65 | return train_loader, valset, collate_fn 66 | 67 | 68 | def prepare_directories_and_logger(output_directory, log_directory, rank): 69 | if rank == 0: 70 | if not os.path.isdir(output_directory): 71 | os.makedirs(output_directory) 72 | os.chmod(output_directory, 0o775) 73 | logger = ParrotLogger(os.path.join(output_directory, log_directory)) 74 | else: 75 | logger = None 76 | return logger 77 | 78 | 79 | def load_model(hparams): 80 | model = Parrot(hparams).cuda() 81 | if hparams.distributed_run: 82 | model = apply_gradient_allreduce(model) 83 | 84 | return model 85 | 86 | 87 | def warm_start_model(checkpoint_path, model): 88 | assert os.path.isfile(checkpoint_path) 89 | print(("Warm starting model from checkpoint '{}'".format(checkpoint_path))) 90 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 91 | model.load_state_dict(checkpoint_dict['state_dict']) 92 | return model 93 | 94 | 95 | def load_checkpoint(checkpoint_path, model, optimizer_main, optimizer_sc): 96 | assert os.path.isfile(checkpoint_path) 97 | print(("Loading checkpoint '{}'".format(checkpoint_path))) 98 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 99 | model.load_state_dict(checkpoint_dict['state_dict']) 100 | optimizer_main.load_state_dict(checkpoint_dict['optimizer_main']) 101 | optimizer_sc.load_state_dict(checkpoint_dict['optimizer_sc']) 102 | learning_rate = checkpoint_dict['learning_rate'] 103 | iteration = checkpoint_dict['iteration'] 104 | print(("Loaded checkpoint '{}' from iteration {}" .format( 105 | checkpoint_path, iteration))) 106 | return model, optimizer_main, optimizer_sc, learning_rate, iteration 107 | 108 | 109 | def save_checkpoint(model, optimizer_main, optimizer_sc, learning_rate, iteration, filepath): 110 | print(("Saving model and optimizer state at iteration {} to {}".format( 111 | iteration, filepath))) 112 | torch.save({'iteration': iteration, 113 | 'state_dict': model.state_dict(), 114 | 'optimizer_main': optimizer_main.state_dict(), 115 | 'optimizer_sc': optimizer_sc.state_dict(), 116 | 'learning_rate': learning_rate}, filepath) 117 | 118 | 119 | def validate(model, criterion, valset, iteration, batch_size, n_gpus, 120 | collate_fn, logger, distributed_run, rank): 121 | """Handles all the validation scoring and printing""" 122 | model.eval() 123 | with torch.no_grad(): 124 | val_sampler = DistributedSampler(valset) if distributed_run else None 125 | val_loader = 
DataLoader(valset, sampler=val_sampler, num_workers=1, 126 | shuffle=False, batch_size=batch_size, 127 | drop_last=True, 128 | pin_memory=False, collate_fn=collate_fn) 129 | 130 | val_loss_tts, val_loss_vc = 0.0, 0.0 131 | reduced_val_tts_losses, reduced_val_vc_losses = np.zeros([8], dtype=np.float32), np.zeros([8], dtype=np.float32) 132 | reduced_val_tts_acces, reduced_val_vc_acces = np.zeros([3], dtype=np.float32), np.zeros([3], dtype=np.float32) 133 | 134 | for i, batch in enumerate(val_loader): 135 | 136 | x, y = model.parse_batch(batch) 137 | 138 | if i%2 == 0: 139 | y_pred = model(x, True) 140 | else: 141 | y_pred = model(x, False) 142 | 143 | losses, acces, l_main, l_sc = criterion(y_pred, y, False) 144 | if distributed_run: 145 | reduced_val_losses = [] 146 | reduced_val_acces = [] 147 | 148 | for l in losses: 149 | reduced_val_losses.append(reduce_tensor(l.data, n_gpus).item()) 150 | for a in acces: 151 | reduced_val_acces.append(reduce_tensor(a.data, n_gpus).item()) 152 | 153 | l_main = reduce_tensor(l_main.data, n_gpus).item() 154 | l_sc = reduce_tensor(l_sc.data, n_gpus).item() 155 | else: 156 | reduced_val_losses = [l.item() for l in losses] 157 | reduced_val_acces = [a.item() for a in acces] 158 | l_main = l_main.item() 159 | l_sc = l_sc.item() 160 | 161 | if i%2 == 0: 162 | val_loss_tts += l_main + l_sc 163 | y_tts = y 164 | y_tts_pred = y_pred 165 | reduced_val_tts_losses += np.array(reduced_val_losses) 166 | reduced_val_tts_acces += np.array(reduced_val_acces) 167 | else: 168 | val_loss_vc += l_main + l_sc 169 | y_vc = y 170 | y_vc_pred = y_pred 171 | reduced_val_vc_losses += np.array(reduced_val_losses) 172 | reduced_val_vc_acces += np.array(reduced_val_acces) 173 | 174 | if i % 2 == 0: 175 | num_tts = i / 2 + 1 176 | num_vc = i / 2 177 | else: 178 | num_tts = (i + 1) / 2 179 | num_vc = (i + 1) / 2 180 | 181 | val_loss_tts = val_loss_tts / num_tts 182 | val_loss_vc = val_loss_vc / num_vc 183 | reduced_val_tts_acces = reduced_val_tts_acces / num_tts 184 | reduced_val_vc_acces = reduced_val_vc_acces / num_vc 185 | reduced_val_tts_losses = reduced_val_tts_losses / num_tts 186 | reduced_val_vc_losses = reduced_val_vc_losses / num_vc 187 | 188 | model.train() 189 | if rank == 0: 190 | print(("Validation loss {}: TTS {:9f} VC {:9f}".format(iteration, val_loss_tts, val_loss_vc))) 191 | logger.log_validation(val_loss_tts, reduced_val_tts_losses, reduced_val_tts_acces, model, y_tts, y_tts_pred, iteration, 'tts') 192 | logger.log_validation(val_loss_vc, reduced_val_vc_losses, reduced_val_vc_acces, model, y_vc, y_vc_pred, iteration, 'vc') 193 | 194 | 195 | 196 | def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, 197 | rank, group_name, hparams): 198 | 199 | """Training and validation logging results to tensorboard and stdout 200 | Params 201 | ------ 202 | output_directory (string): directory to save checkpoints 203 | log_directory (string) directory to save tensorboard logs 204 | checkpoint_path(string): checkpoint path 205 | n_gpus (int): number of gpus 206 | rank (int): rank of current gpu 207 | hparams (object): comma separated list of "name=value" pairs. 
208 | """ 209 | 210 | if hparams.distributed_run: 211 | init_distributed(hparams, n_gpus, rank, group_name) 212 | 213 | torch.manual_seed(hparams.seed) 214 | torch.cuda.manual_seed(hparams.seed) 215 | 216 | model = load_model(hparams) 217 | learning_rate = hparams.learning_rate 218 | 219 | parameters_main, parameters_sc = model.grouped_parameters() 220 | 221 | optimizer_main = torch.optim.Adam(parameters_main, lr=learning_rate, 222 | weight_decay=hparams.weight_decay) 223 | optimizer_sc = torch.optim.Adam(parameters_sc, lr=learning_rate, 224 | weight_decay=hparams.weight_decay) 225 | 226 | if hparams.distributed_run: 227 | model = apply_gradient_allreduce(model) 228 | 229 | criterion = ParrotLoss(hparams).cuda() 230 | 231 | logger = prepare_directories_and_logger( 232 | output_directory, log_directory, rank) 233 | 234 | train_loader, valset, collate_fn = prepare_dataloaders(hparams) 235 | 236 | # Load checkpoint if one exists 237 | iteration = 0 238 | epoch_offset = 0 239 | if checkpoint_path is not None: 240 | if warm_start: 241 | model = warm_start_model(checkpoint_path, model) 242 | else: 243 | model, optimizer_main, optimizer_sc, _learning_rate, iteration = load_checkpoint( 244 | checkpoint_path, model, optimizer_main, optimizer_sc) 245 | if hparams.use_saved_learning_rate: 246 | learning_rate = _learning_rate 247 | iteration += 1 # next iteration is iteration + 1 248 | epoch_offset = max(0, int(iteration / len(train_loader))) 249 | 250 | model.train() 251 | # ================ MAIN TRAINING LOOP! =================== 252 | for epoch in range(epoch_offset, hparams.epochs): 253 | print(("Epoch: {}".format(epoch))) 254 | 255 | for i, batch in enumerate(train_loader): 256 | 257 | start = time.time() 258 | 259 | for param_group in optimizer_main.param_groups: 260 | param_group['lr'] = learning_rate 261 | 262 | for param_group in optimizer_sc.param_groups: 263 | param_group['lr'] = learning_rate 264 | 265 | 266 | 267 | model.zero_grad() 268 | x, y = model.parse_batch(batch) 269 | 270 | if i % 2 == 0: # even batches train the TTS path (text input), odd batches the VC path (mel input) 271 | y_pred = model(x, True) 272 | losses, acces, l_main, l_sc = criterion(y_pred, y, True) 273 | else: 274 | y_pred = model(x, False) 275 | losses, acces, l_main, l_sc = criterion(y_pred, y, False) 276 | 277 | if hparams.distributed_run: 278 | reduced_losses = [] 279 | for l in losses: 280 | reduced_losses.append(reduce_tensor(l.data, n_gpus).item()) 281 | reduced_acces = [] 282 | for a in acces: 283 | reduced_acces.append(reduce_tensor(a.data, n_gpus).item()) 284 | redl_main = reduce_tensor(l_main.data, n_gpus).item() 285 | redl_sc = reduce_tensor(l_sc.data, n_gpus).item() 286 | else: 287 | reduced_losses = [l.item() for l in losses] 288 | reduced_acces = [a.item() for a in acces] 289 | redl_main = l_main.item() 290 | redl_sc = l_sc.item() 291 | 292 | for p in parameters_sc: # freeze the speaker classifier while updating the main network 293 | p.requires_grad_(requires_grad=False) 294 | 295 | l_main.backward(retain_graph=True) 296 | grad_norm_main = torch.nn.utils.clip_grad_norm_( 297 | parameters_main, hparams.grad_clip_thresh) 298 | 299 | optimizer_main.step() 300 | 301 | for p in parameters_sc: 302 | p.requires_grad_(requires_grad=True) 303 | for p in parameters_main: # then freeze the main network and update only the speaker classifier 304 | p.requires_grad_(requires_grad=False) 305 | 306 | 307 | l_sc.backward() 308 | grad_norm_sc = torch.nn.utils.clip_grad_norm_( 309 | parameters_sc, hparams.grad_clip_thresh) 310 | 311 | 312 | optimizer_sc.step() 313 | 314 | for p in parameters_main: 315 | p.requires_grad_(requires_grad=True) 316 | 317 | if not math.isnan(redl_main) and rank == 0: 318 | 319 | duration = time.time() -
start 320 | task = 'TTS' if i%2 == 0 else 'VC' 321 | print(("Train {} {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( 322 | task, iteration, redl_main+redl_sc, grad_norm_main, duration))) 323 | logger.log_training( 324 | redl_main+redl_sc, reduced_losses, reduced_acces, grad_norm_main, learning_rate, duration, iteration) 325 | 326 | if (iteration % hparams.iters_per_checkpoint == 0): 327 | validate(model, criterion, valset, iteration, 328 | hparams.batch_size, n_gpus, collate_fn, logger, 329 | hparams.distributed_run, rank) 330 | if rank == 0: 331 | checkpoint_path = os.path.join( 332 | output_directory, "checkpoint_{}".format(iteration)) 333 | save_checkpoint(model, optimizer_main, optimizer_sc, learning_rate, iteration, 334 | checkpoint_path) 335 | 336 | iteration += 1 337 | 338 | 339 | if __name__ == '__main__': 340 | parser = argparse.ArgumentParser() 341 | parser.add_argument('-o', '--output_directory', type=str, 342 | help='directory to save checkpoints') 343 | parser.add_argument('-l', '--log_directory', type=str, 344 | help='directory to save tensorboard logs') 345 | parser.add_argument('-c', '--checkpoint_path', type=str, default=None, 346 | required=False, help='checkpoint path') 347 | parser.add_argument('--warm_start', action='store_true', 348 | help='load the model only (warm start)') 349 | parser.add_argument('--n_gpus', type=int, default=1, 350 | required=False, help='number of gpus') 351 | parser.add_argument('--rank', type=int, default=0, 352 | required=False, help='rank of current gpu') 353 | parser.add_argument('--group_name', type=str, default='group_name', 354 | required=False, help='Distributed group name') 355 | parser.add_argument('--hparams', type=str, 356 | required=False, help='comma separated name=value pairs') 357 | 358 | args = parser.parse_args() 359 | hparams = create_hparams(args.hparams) 360 | 361 | torch.backends.cudnn.enabled = hparams.cudnn_enabled 362 | torch.backends.cudnn.benchmark = hparams.cudnn_benchmark 363 | 364 | print(("Distributed Run:", hparams.distributed_run)) 365 | print(("cuDNN Enabled:", hparams.cudnn_enabled)) 366 | print(("cuDNN Benchmark:", hparams.cudnn_benchmark)) 367 | 368 | train(args.output_directory, args.log_directory, args.checkpoint_path, 369 | args.warm_start, args.n_gpus, args.rank, args.group_name, hparams) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.1 2 | tensorflow==1.14 3 | numpy 4 | typing 5 | tensorboardX 6 | matplotlib 7 | phonemizer 8 | librosa 9 | Pillow 10 | -------------------------------------------------------------------------------- /struct.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxzhanggg/nonparaSeq2seqVC_code/4c03a6be3bc76207b7cf8222c985dc85c7018cde/struct.PNG --------------------------------------------------------------------------------