├── .gitmodules
├── LICENSE
├── README.md
├── config.json
├── convert_model.py
├── denoiser.py
├── distributed.py
├── glow.py
├── glow_old.py
├── inference.py
├── mel2samp.py
├── requirements.txt
├── train.py
└── waveglow_logo.png

/.gitmodules:
--------------------------------------------------------------------------------
[submodule "tacotron2"]
	path = tacotron2
	url = http://github.com/NVIDIA/tacotron2

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2018, NVIDIA Corporation
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![WaveGlow](waveglow_logo.png "WaveGlow")

## WaveGlow: a Flow-based Generative Network for Speech Synthesis

### Ryan Prenger, Rafael Valle, and Bryan Catanzaro

In our recent [paper], we propose WaveGlow: a flow-based network capable of
generating high-quality speech from mel-spectrograms. WaveGlow combines insights
from [Glow] and [WaveNet] in order to provide fast, efficient and high-quality
audio synthesis, without the need for auto-regression. WaveGlow is implemented
using only a single network, trained using only a single cost function:
maximizing the likelihood of the training data, which makes the training
procedure simple and stable.

Our [PyTorch] implementation produces audio samples at a rate of 1200
kHz on an NVIDIA V100 GPU. Mean Opinion Scores show that it delivers audio
quality as good as the best publicly available [WaveNet implementation].

Visit our [website] for audio samples.

## Setup

1. Clone our repo and initialize the submodule

   ```command
   git clone https://github.com/NVIDIA/waveglow.git
   cd waveglow
   git submodule init
   git submodule update
   ```

2. Install requirements: `pip3 install -r requirements.txt`

3. Install [Apex]


## Generate audio with our pre-existing model

1. Download our [published model]
2. Download [mel-spectrograms]
3. Generate audio: `python3 inference.py -f <(ls mel_spectrograms/*.pt) -w waveglow_256channels.pt -o . --is_fp16 -s 0.6`

N.b. use `convert_model.py` to convert your older models to the current model
with fused residual and skip connections.

## Train your own model

1. Download [LJ Speech Data]. In this example, it is stored in `data/`

2. Make a list of the file names to use for training/testing

   ```command
   ls data/*.wav | tail -n+10 > train_files.txt
   ls data/*.wav | head -n10 > test_files.txt
   ```

3. Train your WaveGlow networks

   ```command
   mkdir checkpoints
   python train.py -c config.json
   ```

   For multi-GPU training, replace `train.py` with `distributed.py`. Only tested with a single node and NCCL.

   For mixed-precision training, set `"fp16_run": true` in `config.json`.

4. Make test set mel-spectrograms

   `python mel2samp.py -f test_files.txt -o . -c config.json`

5. Run inference with your network

   ```command
   ls *.pt > mel_files.txt
   python3 inference.py -f mel_files.txt -w checkpoints/waveglow_10000 -o . --is_fp16 -s 0.6
   ```

[//]: # (TODO)
[//]: # (PROVIDE INSTRUCTIONS FOR DOWNLOADING LJS)
[pytorch 1.0]: https://github.com/pytorch/pytorch#installation
[website]: https://nv-adlr.github.io/WaveGlow
[paper]: https://arxiv.org/abs/1811.00002
[WaveNet implementation]: https://github.com/r9y9/wavenet_vocoder
[Glow]: https://blog.openai.com/glow/
[WaveNet]: https://deepmind.com/blog/wavenet-generative-model-raw-audio/
[PyTorch]: http://pytorch.org
[published model]: https://drive.google.com/open?id=1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF
[mel-spectrograms]: https://drive.google.com/file/d/1g_VXK2lpP9J25dQFhQwx7doWl_p20fXA/view?usp=sharing
[LJ Speech Data]: https://keithito.com/LJ-Speech-Dataset
[Apex]: https://github.com/nvidia/apex

--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
{
    "train_config": {
        "fp16_run": true,
        "output_directory": "checkpoints",
        "epochs": 100000,
        "learning_rate": 1e-4,
        "sigma": 1.0,
        "iters_per_checkpoint": 2000,
        "batch_size": 12,
        "seed": 1234,
        "checkpoint_path": "",
        "with_tensorboard": false
    },
    "data_config": {
        "training_files": "train_files.txt",
        "segment_length": 16000,
        "sampling_rate": 22050,
        "filter_length": 1024,
        "hop_length": 256,
        "win_length": 1024,
        "mel_fmin": 0.0,
        "mel_fmax": 8000.0
    },
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321"
    },

    "waveglow_config": {
        "n_mel_channels": 80,
        "n_flows": 12,
        "n_group": 8,
        "n_early_every": 4,
        "n_early_size": 2,
        "WN_config": {
            "n_layers": 8,
            "n_channels": 256,
            "kernel_size": 3
        }
    }
}

--------------------------------------------------------------------------------
/convert_model.py:
--------------------------------------------------------------------------------
import sys
import copy
import torch

def _check_model_old_version(model):
    if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'):
        return True
    else:
        return False


def _update_model_res_skip(old_model, new_model):
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        wavenet.res_skip_layers = torch.nn.ModuleList()
        for i in range(0, n_layers):
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
            if i < n_layers - 1:
                res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
                res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
                res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
            else:
                res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
                res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            wavenet.res_skip_layers.append(res_skip_layer)
        del wavenet.res_layers
        del wavenet.skip_layers

def _update_model_cond(old_model, new_model):
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        n_mel_channels = wavenet.cond_layers[0].weight.shape[1]
        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
        cond_layer_weight = []
        cond_layer_bias = []
        for i in range(0, n_layers):
            _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i])
            cond_layer_weight.append(_cond_layer.weight)
            cond_layer_bias.append(_cond_layer.bias)
        cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight))
        cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias))
        cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
        wavenet.cond_layer = cond_layer
        del wavenet.cond_layers

def update_model(old_model):
    if not _check_model_old_version(old_model):
        return old_model
    new_model = copy.deepcopy(old_model)
    if hasattr(old_model.WN[0], 'res_layers'):
        _update_model_res_skip(old_model, new_model)
    if hasattr(old_model.WN[0], 'cond_layers'):
        _update_model_cond(old_model, new_model)
    for m in new_model.modules():
        if 'Conv' in str(type(m)) and not hasattr(m, 'padding_mode'):
            setattr(m, 'padding_mode', 'zeros')
    return new_model

if __name__ == '__main__':
    old_model_path = sys.argv[1]
    new_model_path = sys.argv[2]
    model = torch.load(old_model_path, map_location='cpu')
    model['model'] = update_model(model['model'])
    torch.save(model, new_model_path)


--------------------------------------------------------------------------------
/denoiser.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('tacotron2')
import torch
from layers import STFT


class Denoiser(torch.nn.Module):
    """ Removes model bias from audio produced with waveglow """

    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros'):
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length/n_overlap),
                         win_length=win_length).cuda()
        if mode == 'zeros':
            mel_input = torch.zeros(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        elif mode == 'normal':
            mel_input = torch.randn(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        else:
            raise Exception("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised

--------------------------------------------------------------------------------
/distributed.py:
--------------------------------------------------------------------------------
# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import os
import sys
import time
import subprocess
import argparse

import torch
import torch.distributed as dist
from torch.autograd import Variable

def reduce_tensor(tensor, num_gpus):
    rt = tensor.clone()
    # dist.reduce_op is deprecated; dist.ReduceOp is the supported name
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt

def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(dist_backend, init_method=dist_url,
                            world_size=num_gpus, rank=rank,
                            group_name=group_name)

def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.
    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer.
    Element-wise operation on this buffer will be equivalent to
    operating individually.
    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten.
    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
    return flat

def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.
    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.
    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)

def apply_gradient_allreduce(module):
    """
    Modifies existing model to do gradient allreduce, but doesn't change class
    so you don't need "module"
    """
    if not hasattr(dist, '_backend'):
        module.warn_on_half = True
    else:
        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

    for p in module.state_dict().values():
        if not torch.is_tensor(p):
            continue
        dist.broadcast(p, 0)

    def allreduce_params():
        if(module.needs_reduction):
            module.needs_reduction = False
            buckets = {}
            for param in module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = type(param.data)
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
            if module.warn_on_half:
                if torch.cuda.HalfTensor in buckets:
                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                          " It is recommended to use the NCCL backend in this case. This currently requires" +
                          " PyTorch built from top of tree master.")
                    module.warn_on_half = False

            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                    buf.copy_(synced)

    for param in list(module.parameters()):
        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(allreduce_params)
        if param.requires_grad:
            param.register_hook(allreduce_hook)
            dir(param)

    def set_needs_reduction(self, input, output):
        self.needs_reduction = True

    module.register_forward_hook(set_needs_reduction)
    return module


def main(config, stdout_dir, args_str):
    args_list = ['train.py']
    args_list += args_str.split(' ') if len(args_str) > 0 else []

    args_list.append('--config={}'.format(config))

    num_gpus = torch.cuda.device_count()
    args_list.append('--num_gpus={}'.format(num_gpus))
    args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S")))

    if not os.path.isdir(stdout_dir):
        os.makedirs(stdout_dir)
        os.chmod(stdout_dir, 0o775)

    workers = []

    for i in range(num_gpus):
        args_list[-2] = '--rank={}'.format(i)
        stdout = None if i == 0 else open(
            os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w")
        print(args_list)
        p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout)
        workers.append(p)

    for p in workers:
        p.wait()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--config', type=str, required=True,
                        help='JSON file for configuration')
    parser.add_argument('-s', '--stdout_dir', type=str, default=".",
                        help='directory to save stdout logs')
    parser.add_argument(
        '-a', '--args_str', type=str, default='',
        help='double quoted string with space separated key value pairs')

    args = parser.parse_args()
    main(args.config, args.stdout_dir, args.args_str)

--------------------------------------------------------------------------------
/glow.py:
--------------------------------------------------------------------------------
# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import copy
import torch
from torch.autograd import Variable
import torch.nn.functional as F


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a+input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class WaveGlowLoss(torch.nn.Module):
    def __init__(self, sigma=1.0):
        super(WaveGlowLoss, self).__init__()
        self.sigma = sigma

    def forward(self, model_output):
        z, log_s_list, log_det_W_list = model_output
        for i, log_s in enumerate(log_s_list):
            if i == 0:
                log_s_total = torch.sum(log_s)
                log_det_W_total = log_det_W_list[i]
            else:
                log_s_total = log_s_total + torch.sum(log_s)
                log_det_W_total += log_det_W_list[i]

        loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
        return loss/(z.size(0)*z.size(1)*z.size(2))


class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix.
    If reverse=True it does convolution with
    inverse
    """
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Sample a random orthonormal matrix to initialize weights
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]

        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:,0] = -1*W[:,0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W

    def forward(self, z, reverse=False):
        # shape
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None])
                if z.type() == 'torch.cuda.HalfTensor':
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W)
            z = self.conv(z)
            return z, log_det_W


class WN(torch.nn.Module):
    """
    This is the WaveNet like layer for the affine coupling. The primary difference
    from WaveNet is the convolutions need not be causal. There is also no dilation
    size reset.
    The dilation only doubles on each layer
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)


            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        audio, spect = forward_input
        audio = self.start(audio)
        output = torch.zeros_like(audio)
        n_channels_tensor = torch.IntTensor([self.n_channels])

        spect = self.cond_layer(spect)

        for i in range(self.n_layers):
            spect_offset = i*2*self.n_channels
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
                n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = audio + res_skip_acts[:,:self.n_channels,:]
                output = output + res_skip_acts[:,self.n_channels:,:]
            else:
                output = output + res_skip_acts

        return self.end(output)


class WaveGlow(torch.nn.Module):
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = int(n_group/2)

        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels  # Useful during inference

    def forward(self, forward_input):
        """
        forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
        forward_input[1] = audio: batch x time
        """
        spect, audio = forward_input

        # Upsample spectrogram to size of audio
        spect = self.upsample(spect)
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_W_list = []

        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:,:self.n_early_size,:])
                audio = audio[:,self.n_early_size:,:]

            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s)*audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1],1)

        output_audio.append(audio)
        return torch.cat(output_audio,1), log_s_list, log_det_W_list

    def infer(self, spect, sigma=1.0):
        spect = self.upsample(spect)
        # trim conv artifacts. maybe pad spec to kernel multiple
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :-time_cutoff]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        if spect.type() == 'torch.cuda.HalfTensor':
            audio = torch.cuda.HalfTensor(spect.size(0),
                                          self.n_remaining_channels,
                                          spect.size(2)).normal_()
        else:
            audio = torch.cuda.FloatTensor(spect.size(0),
                                           self.n_remaining_channels,
                                           spect.size(2)).normal_()

        audio = torch.autograd.Variable(sigma*audio)

        for k in reversed(range(self.n_flows)):
            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))

            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b)/torch.exp(s)
            audio = torch.cat([audio_0, audio_1],1)

            audio = self.convinv[k](audio, reverse=True)

            if k % self.n_early_every == 0 and k > 0:
                if spect.type() == 'torch.cuda.HalfTensor':
                    z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                else:
                    z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                audio = torch.cat((sigma*z, audio),1)

        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
        return audio

    @staticmethod
    def remove_weightnorm(model):
        waveglow = model
        for WN in waveglow.WN:
            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
            WN.in_layers = remove(WN.in_layers)
            WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
            WN.res_skip_layers = remove(WN.res_skip_layers)
        return waveglow


def remove(conv_list):
    new_conv_list = torch.nn.ModuleList()
    for old_conv in conv_list:
        old_conv = torch.nn.utils.remove_weight_norm(old_conv)
        new_conv_list.append(old_conv)
    return new_conv_list

--------------------------------------------------------------------------------
/glow_old.py:
--------------------------------------------------------------------------------
import copy
import torch
from glow import Invertible1x1Conv, remove


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a+input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class WN(torch.nn.Module):
    """
    This is the WaveNet like layer for the affine coupling. The primary difference
    from WaveNet is the convolutions need not be causal. There is also no dilation
    size reset. The dilation only doubles on each layer
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.cond_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
            cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
            self.cond_layers.append(cond_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        audio, spect = forward_input
        audio = self.start(audio)

        for i in range(self.n_layers):
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                self.cond_layers[i](spect),
                torch.IntTensor([self.n_channels]))

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = res_skip_acts[:,:self.n_channels,:] + audio
                skip_acts = res_skip_acts[:,self.n_channels:,:]
            else:
                skip_acts = res_skip_acts

            if i == 0:
                output = skip_acts
            else:
                output = skip_acts + output
        return self.end(output)


class WaveGlow(torch.nn.Module):
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
assert(n_group % 2 == 0) 98 | self.n_flows = n_flows 99 | self.n_group = n_group 100 | self.n_early_every = n_early_every 101 | self.n_early_size = n_early_size 102 | self.WN = torch.nn.ModuleList() 103 | self.convinv = torch.nn.ModuleList() 104 | 105 | n_half = int(n_group/2) 106 | 107 | # Set up layers with the right sizes based on how many dimensions 108 | # have been output already 109 | n_remaining_channels = n_group 110 | for k in range(n_flows): 111 | if k % self.n_early_every == 0 and k > 0: 112 | n_half = n_half - int(self.n_early_size/2) 113 | n_remaining_channels = n_remaining_channels - self.n_early_size 114 | self.convinv.append(Invertible1x1Conv(n_remaining_channels)) 115 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) 116 | self.n_remaining_channels = n_remaining_channels # Useful during inference 117 | 118 | def forward(self, forward_input): 119 | return None 120 | """ 121 | forward_input[0] = audio: batch x time 122 | forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time 123 | """ 124 | """ 125 | spect, audio = forward_input 126 | 127 | # Upsample spectrogram to size of audio 128 | spect = self.upsample(spect) 129 | assert(spect.size(2) >= audio.size(1)) 130 | if spect.size(2) > audio.size(1): 131 | spect = spect[:, :, :audio.size(1)] 132 | 133 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 134 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 135 | 136 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1) 137 | output_audio = [] 138 | s_list = [] 139 | s_conv_list = [] 140 | 141 | for k in range(self.n_flows): 142 | if k%4 == 0 and k > 0: 143 | output_audio.append(audio[:,:self.n_multi,:]) 144 | audio = audio[:,self.n_multi:,:] 145 | 146 | # project to new basis 147 | audio, s = self.convinv[k](audio) 148 | s_conv_list.append(s) 149 | 150 | n_half = int(audio.size(1)/2) 151 | if k%2 == 0: 152 | audio_0 = audio[:,:n_half,:] 153 | audio_1 
= audio[:,n_half:,:] 154 | else: 155 | audio_1 = audio[:,:n_half,:] 156 | audio_0 = audio[:,n_half:,:] 157 | 158 | output = self.nn[k]((audio_0, spect)) 159 | s = output[:, n_half:, :] 160 | b = output[:, :n_half, :] 161 | audio_1 = torch.exp(s)*audio_1 + b 162 | s_list.append(s) 163 | 164 | if k%2 == 0: 165 | audio = torch.cat([audio[:,:n_half,:], audio_1],1) 166 | else: 167 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1) 168 | output_audio.append(audio) 169 | return torch.cat(output_audio,1), s_list, s_conv_list 170 | """ 171 | 172 | def infer(self, spect, sigma=1.0): 173 | spect = self.upsample(spect) 174 | # trim conv artifacts. maybe pad spec to kernel multiple 175 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] 176 | spect = spect[:, :, :-time_cutoff] 177 | 178 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 179 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 180 | 181 | if spect.type() == 'torch.cuda.HalfTensor': 182 | audio = torch.cuda.HalfTensor(spect.size(0), 183 | self.n_remaining_channels, 184 | spect.size(2)).normal_() 185 | else: 186 | audio = torch.cuda.FloatTensor(spect.size(0), 187 | self.n_remaining_channels, 188 | spect.size(2)).normal_() 189 | 190 | audio = torch.autograd.Variable(sigma*audio) 191 | 192 | for k in reversed(range(self.n_flows)): 193 | n_half = int(audio.size(1)/2) 194 | if k%2 == 0: 195 | audio_0 = audio[:,:n_half,:] 196 | audio_1 = audio[:,n_half:,:] 197 | else: 198 | audio_1 = audio[:,:n_half,:] 199 | audio_0 = audio[:,n_half:,:] 200 | 201 | output = self.WN[k]((audio_0, spect)) 202 | s = output[:, n_half:, :] 203 | b = output[:, :n_half, :] 204 | audio_1 = (audio_1 - b)/torch.exp(s) 205 | if k%2 == 0: 206 | audio = torch.cat([audio[:,:n_half,:], audio_1],1) 207 | else: 208 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1) 209 | 210 | audio = self.convinv[k](audio, reverse=True) 211 | 212 | if k%4 == 0 and k > 0: 213 | if 
spect.type() == 'torch.cuda.HalfTensor': 214 | z = torch.cuda.HalfTensor(spect.size(0), 215 | self.n_early_size, 216 | spect.size(2)).normal_() 217 | else: 218 | z = torch.cuda.FloatTensor(spect.size(0), 219 | self.n_early_size, 220 | spect.size(2)).normal_() 221 | audio = torch.cat((sigma*z, audio),1) 222 | 223 | return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data 224 | 225 | @staticmethod 226 | def remove_weightnorm(model): 227 | waveglow = model 228 | for WN in waveglow.WN: 229 | WN.start = torch.nn.utils.remove_weight_norm(WN.start) 230 | WN.in_layers = remove(WN.in_layers) 231 | WN.cond_layers = remove(WN.cond_layers) 232 | WN.res_skip_layers = remove(WN.res_skip_layers) 233 | return waveglow 234 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 
14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | # ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import os 28 | from scipy.io.wavfile import write 29 | import torch 30 | from mel2samp import files_to_list, MAX_WAV_VALUE 31 | from denoiser import Denoiser 32 | 33 | 34 | def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16, 35 | denoiser_strength): 36 | mel_files = files_to_list(mel_files) 37 | waveglow = torch.load(waveglow_path)['model'] 38 | waveglow = waveglow.remove_weightnorm(waveglow) 39 | waveglow.cuda().eval() 40 | if is_fp16: 41 | from apex import amp 42 | waveglow, _ = amp.initialize(waveglow, [], opt_level="O3") 43 | 44 | if denoiser_strength > 0: 45 | denoiser = Denoiser(waveglow).cuda() 46 | 47 | for i, file_path in enumerate(mel_files): 48 | file_name = os.path.splitext(os.path.basename(file_path))[0] 49 | mel = torch.load(file_path) 50 | mel = torch.autograd.Variable(mel.cuda()) 51 | mel = torch.unsqueeze(mel, 0) 52 | mel = mel.half() if is_fp16 else mel 53 | with torch.no_grad(): 54 | audio = waveglow.infer(mel, sigma=sigma) 55 | if denoiser_strength > 0: 56 | audio = denoiser(audio, denoiser_strength) 57 | audio = audio * 
MAX_WAV_VALUE 58 | audio = audio.squeeze() 59 | audio = audio.cpu().numpy() 60 | audio = audio.clip(-32768, 32767).astype('int16') # clip first: full-scale samples would wrap around in int16 61 | audio_path = os.path.join( 62 | output_dir, "{}_synthesis.wav".format(file_name)) 63 | write(audio_path, sampling_rate, audio) 64 | print(audio_path) 65 | 66 | 67 | if __name__ == "__main__": 68 | import argparse 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('-f', "--filelist_path", required=True) 72 | parser.add_argument('-w', '--waveglow_path', required=True, 73 | help='Path to waveglow decoder checkpoint with model') 74 | parser.add_argument('-o', "--output_dir", required=True) 75 | parser.add_argument("-s", "--sigma", default=1.0, type=float) 76 | parser.add_argument("--sampling_rate", default=22050, type=int) 77 | parser.add_argument("--is_fp16", action="store_true") 78 | parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float, 79 | help='Removes model bias. Start with 0.1 and adjust') 80 | 81 | args = parser.parse_args() 82 | 83 | main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir, 84 | args.sampling_rate, args.is_fp16, args.denoiser_strength) 85 | -------------------------------------------------------------------------------- /mel2samp.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | # 26 | # ***************************************************************************** 27 | import os 28 | import random 29 | import argparse 30 | import json 31 | import torch 32 | import torch.utils.data 33 | import sys 34 | from scipy.io.wavfile import read 35 | 36 | # We're using the audio processing from Tacotron 2 to make sure it matches 37 | sys.path.insert(0, 'tacotron2') 38 | from tacotron2.layers import TacotronSTFT 39 | 40 | MAX_WAV_VALUE = 32768.0 41 | 42 | def files_to_list(filename): 43 | """ 44 | Takes a text file of filenames and makes a list of filenames 45 | """ 46 | with open(filename, encoding='utf-8') as f: 47 | files = f.readlines() 48 | 49 | files = [f.rstrip() for f in files] 50 | return files 51 | 52 | def load_wav_to_torch(full_path): 53 | """ 54 | Loads WAV data into a torch tensor and returns it with its sampling rate 55 | """ 56 | sampling_rate, data = read(full_path) 57 | return torch.from_numpy(data).float(), sampling_rate 58 | 59 | 60 | class Mel2Samp(torch.utils.data.Dataset): 61 | """ 62 | This is the main class that calculates the spectrogram and returns the 63 | spectrogram, audio pair.
64 | """ 65 | def __init__(self, training_files, segment_length, filter_length, 66 | hop_length, win_length, sampling_rate, mel_fmin, mel_fmax): 67 | self.audio_files = files_to_list(training_files) 68 | random.seed(1234) 69 | random.shuffle(self.audio_files) 70 | self.stft = TacotronSTFT(filter_length=filter_length, 71 | hop_length=hop_length, 72 | win_length=win_length, 73 | sampling_rate=sampling_rate, 74 | mel_fmin=mel_fmin, mel_fmax=mel_fmax) 75 | self.segment_length = segment_length 76 | self.sampling_rate = sampling_rate 77 | 78 | def get_mel(self, audio): 79 | audio_norm = audio / MAX_WAV_VALUE 80 | audio_norm = audio_norm.unsqueeze(0) 81 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 82 | melspec = self.stft.mel_spectrogram(audio_norm) 83 | melspec = torch.squeeze(melspec, 0) 84 | return melspec 85 | 86 | def __getitem__(self, index): 87 | # Read audio 88 | filename = self.audio_files[index] 89 | audio, sampling_rate = load_wav_to_torch(filename) 90 | if sampling_rate != self.sampling_rate: 91 | raise ValueError("{} SR doesn't match target {} SR".format( 92 | sampling_rate, self.sampling_rate)) 93 | 94 | # Take segment 95 | if audio.size(0) >= self.segment_length: 96 | max_audio_start = audio.size(0) - self.segment_length 97 | audio_start = random.randint(0, max_audio_start) 98 | audio = audio[audio_start:audio_start+self.segment_length] 99 | else: 100 | audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data 101 | 102 | mel = self.get_mel(audio) 103 | audio = audio / MAX_WAV_VALUE 104 | 105 | return (mel, audio) 106 | 107 | def __len__(self): 108 | return len(self.audio_files) 109 | 110 | # =================================================================== 111 | # Takes directory of clean audio and makes directory of spectrograms 112 | # Useful for making test sets 113 | # =================================================================== 114 | if __name__ == "__main__": 115 | # Get 
defaults so it can work with no Sacred 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('-f', "--filelist_path", required=True) 118 | parser.add_argument('-c', '--config', type=str, 119 | help='JSON file for configuration') 120 | parser.add_argument('-o', '--output_dir', type=str, 121 | help='Output directory') 122 | args = parser.parse_args() 123 | 124 | with open(args.config) as f: 125 | data = f.read() 126 | data_config = json.loads(data)["data_config"] 127 | mel2samp = Mel2Samp(**data_config) 128 | 129 | filepaths = files_to_list(args.filelist_path) 130 | 131 | # Make directory if it doesn't exist 132 | if not os.path.isdir(args.output_dir): 133 | os.makedirs(args.output_dir) 134 | os.chmod(args.output_dir, 0o775) 135 | 136 | for filepath in filepaths: 137 | audio, sr = load_wav_to_torch(filepath) 138 | melspectrogram = mel2samp.get_mel(audio) 139 | filename = os.path.basename(filepath) 140 | new_filepath = args.output_dir + '/' + filename + '.pt' 141 | print(new_filepath) 142 | torch.save(melspectrogram, new_filepath) 143 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0 2 | matplotlib==2.1.0 3 | tensorflow 4 | numpy==1.13.3 5 | inflect==0.2.5 6 | librosa==0.6.0 7 | scipy==1.0.0 8 | tensorboardX==1.1 9 | Unidecode==1.0.22 10 | pillow 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | # 26 | # ***************************************************************************** 27 | import argparse 28 | import json 29 | import os 30 | import torch 31 | 32 | #=====START: ADDED FOR DISTRIBUTED====== 33 | from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor 34 | from torch.utils.data.distributed import DistributedSampler 35 | #=====END: ADDED FOR DISTRIBUTED====== 36 | 37 | from torch.utils.data import DataLoader 38 | from glow import WaveGlow, WaveGlowLoss 39 | from mel2samp import Mel2Samp 40 | 41 | def load_checkpoint(checkpoint_path, model, optimizer): 42 | assert os.path.isfile(checkpoint_path) 43 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 44 | iteration = checkpoint_dict['iteration'] 45 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 46 | model_for_loading = checkpoint_dict['model'] 47 | model.load_state_dict(model_for_loading.state_dict()) 48 | print("Loaded checkpoint '{}' (iteration {})" .format( 49 | checkpoint_path, iteration)) 50 | return model, optimizer, iteration 51 | 52 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): 53 | print("Saving model and optimizer state at iteration {} to {}".format( 54 | iteration, filepath)) 55 | model_for_saving = WaveGlow(**waveglow_config).cuda() 56 | model_for_saving.load_state_dict(model.state_dict()) 57 | torch.save({'model': model_for_saving, 58 | 'iteration': iteration, 59 | 'optimizer': optimizer.state_dict(), 60 | 'learning_rate': learning_rate}, filepath) 61 | 62 | def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate, 63 | sigma, iters_per_checkpoint, batch_size, seed, fp16_run, 64 | checkpoint_path, with_tensorboard): 65 | torch.manual_seed(seed) 66 | torch.cuda.manual_seed(seed) 67 | #=====START: ADDED FOR DISTRIBUTED====== 68 | if num_gpus > 1: 69 | init_distributed(rank, num_gpus, group_name, **dist_config) 70 | #=====END: ADDED FOR DISTRIBUTED====== 71 | 72 | criterion = 
WaveGlowLoss(sigma) 73 | model = WaveGlow(**waveglow_config).cuda() 74 | 75 | #=====START: ADDED FOR DISTRIBUTED====== 76 | if num_gpus > 1: 77 | model = apply_gradient_allreduce(model) 78 | #=====END: ADDED FOR DISTRIBUTED====== 79 | 80 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 81 | 82 | if fp16_run: 83 | from apex import amp 84 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 85 | 86 | # Load checkpoint if one exists 87 | iteration = 0 88 | if checkpoint_path != "": 89 | model, optimizer, iteration = load_checkpoint(checkpoint_path, model, 90 | optimizer) 91 | iteration += 1 # next iteration is iteration + 1 92 | 93 | trainset = Mel2Samp(**data_config) 94 | # =====START: ADDED FOR DISTRIBUTED====== 95 | train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None 96 | # =====END: ADDED FOR DISTRIBUTED====== 97 | train_loader = DataLoader(trainset, num_workers=1, shuffle=False, 98 | sampler=train_sampler, 99 | batch_size=batch_size, 100 | pin_memory=False, 101 | drop_last=True) 102 | 103 | # Get shared output_directory ready 104 | if rank == 0: 105 | if not os.path.isdir(output_directory): 106 | os.makedirs(output_directory) 107 | os.chmod(output_directory, 0o775) 108 | print("output directory", output_directory) 109 | 110 | if with_tensorboard and rank == 0: 111 | from tensorboardX import SummaryWriter 112 | logger = SummaryWriter(os.path.join(output_directory, 'logs')) 113 | 114 | model.train() 115 | epoch_offset = max(0, int(iteration / len(train_loader))) 116 | # ================ MAIN TRAINING LOOP!
=================== 117 | for epoch in range(epoch_offset, epochs): 118 | print("Epoch: {}".format(epoch)) 119 | for i, batch in enumerate(train_loader): 120 | model.zero_grad() 121 | 122 | mel, audio = batch 123 | mel = torch.autograd.Variable(mel.cuda()) 124 | audio = torch.autograd.Variable(audio.cuda()) 125 | outputs = model((mel, audio)) 126 | 127 | loss = criterion(outputs) 128 | if num_gpus > 1: 129 | reduced_loss = reduce_tensor(loss.data, num_gpus).item() 130 | else: 131 | reduced_loss = loss.item() 132 | 133 | if fp16_run: 134 | with amp.scale_loss(loss, optimizer) as scaled_loss: 135 | scaled_loss.backward() 136 | else: 137 | loss.backward() 138 | 139 | optimizer.step() 140 | 141 | print("{}:\t{:.9f}".format(iteration, reduced_loss)) 142 | if with_tensorboard and rank == 0: 143 | logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch) 144 | 145 | if (iteration % iters_per_checkpoint == 0): 146 | if rank == 0: 147 | checkpoint_path = "{}/waveglow_{}".format( 148 | output_directory, iteration) 149 | save_checkpoint(model, optimizer, learning_rate, iteration, 150 | checkpoint_path) 151 | 152 | iteration += 1 153 | 154 | if __name__ == "__main__": 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('-c', '--config', type=str, 157 | help='JSON file for configuration') 158 | parser.add_argument('-r', '--rank', type=int, default=0, 159 | help='rank of process for distributed') 160 | parser.add_argument('-g', '--group_name', type=str, default='', 161 | help='name of group for distributed') 162 | args = parser.parse_args() 163 | 164 | # Parse configs. 
Globals nicer in this case 165 | with open(args.config) as f: 166 | data = f.read() 167 | config = json.loads(data) 168 | train_config = config["train_config"] 169 | global data_config 170 | data_config = config["data_config"] 171 | global dist_config 172 | dist_config = config["dist_config"] 173 | global waveglow_config 174 | waveglow_config = config["waveglow_config"] 175 | 176 | num_gpus = torch.cuda.device_count() 177 | if num_gpus > 1: 178 | if args.group_name == '': 179 | print("WARNING: Multiple GPUs detected but no distributed group set") 180 | print("Only running 1 GPU. Use distributed.py for multiple GPUs") 181 | num_gpus = 1 182 | 183 | if num_gpus == 1 and args.rank != 0: 184 | raise Exception("Doing single GPU training on rank > 0") 185 | 186 | torch.backends.cudnn.enabled = True 187 | torch.backends.cudnn.benchmark = False 188 | train(num_gpus, args.rank, args.group_name, **train_config) 189 | -------------------------------------------------------------------------------- /waveglow_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/waveglow/8afb643df59265016af6bd255c7516309d675168/waveglow_logo.png --------------------------------------------------------------------------------
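The shape bookkeeping in `infer()` and the inverse affine coupling it applies can be checked without a GPU. Below is a minimal sketch using plain Python lists in place of tensors, with each list standing for the channel vector at a single time step. The helper names (`upsampled_length`, `grouped_steps`, `remaining_channels`, `toy_wn`, `coupling_forward`, `coupling_inverse`) are hypothetical. `toy_wn` is an arbitrary stand-in for `self.WN[k]`, not the real network. The kernel size 1024 and stride 256 come from the `upsample` layer in `glow_old.py`. The flow settings (`n_group=8`, `n_flows=12`, `n_early_every=4`, `n_early_size=2`) are assumed defaults; check them against your own `config.json`.

```python
import math
import random

# Kernel/stride taken from the upsample layer in glow_old.py.
# The flow settings used as defaults below are ASSUMED values; check config.json.
KERNEL_SIZE, STRIDE = 1024, 256


def upsampled_length(n_frames, kernel_size=KERNEL_SIZE, stride=STRIDE):
    """Samples left after ConvTranspose1d (padding=0) plus the trim done in infer()."""
    full = (n_frames - 1) * stride + kernel_size   # ConvTranspose1d output length
    time_cutoff = kernel_size - stride             # removed from the end in infer()
    return full - time_cutoff                      # equals n_frames * stride


def grouped_steps(n_frames, n_group=8):
    """Time steps left once unfold(2, n_group, n_group) folds samples into channels."""
    return upsampled_length(n_frames) // n_group


def remaining_channels(n_group=8, n_flows=12, n_early_every=4, n_early_size=2):
    """Mirror of the constructor loop that sets self.n_remaining_channels."""
    n_remaining = n_group
    for k in range(n_flows):
        if k % n_early_every == 0 and k > 0:
            n_remaining -= n_early_size   # these channels are emitted early
    return n_remaining


def toy_wn(audio_0, n_half):
    """Hypothetical stand-in for self.WN[k]: any function of audio_0 works here."""
    total = sum(audio_0)
    return [math.tanh(total * (i + 1)) for i in range(2 * n_half)]


def coupling_forward(audio, n_half):
    """Training direction: audio_1 <- exp(s) * audio_1 + b; audio_0 is untouched."""
    audio_0, audio_1 = audio[:n_half], audio[n_half:]
    output = toy_wn(audio_0, n_half)
    b, log_s = output[:n_half], output[n_half:]
    return audio_0 + [math.exp(s) * x + bi for x, s, bi in zip(audio_1, log_s, b)]


def coupling_inverse(audio, n_half):
    """infer() direction: audio_1 <- (audio_1 - b) / exp(s)."""
    audio_0, audio_1 = audio[:n_half], audio[n_half:]
    output = toy_wn(audio_0, n_half)   # same audio_0, so the same s and b
    b, log_s = output[:n_half], output[n_half:]
    return audio_0 + [(y - bi) / math.exp(s) for y, s, bi in zip(audio_1, log_s, b)]


if __name__ == "__main__":
    print(upsampled_length(100), grouped_steps(100), remaining_channels())
    random.seed(0)
    x = [random.gauss(0.0, 1.0) for _ in range(8)]
    x_rec = coupling_inverse(coupling_forward(x, 4), 4)
    print(max(abs(a - c) for a, c in zip(x, x_rec)))
```

Run as a script, the first line prints `25600 3200 4`. Trimming `kernel_size - stride` samples leaves exactly `stride` (256) samples per mel frame. With the assumed settings, 4 channels remain at the last flow, and `infer()` restores the other 4 by concatenating `n_early_size` noise channels twice. The second line is a round-trip error on the order of floating-point rounding. The inverse needs no inversion of `WN` itself because `audio_0` passes through the coupling unchanged, so `WN` recomputes the same `s` and `b`.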