├── egs
│   └── gradtts_n_1000
│       ├── n_1000_steps_259000_ljspeech_hifigan
│       │   ├── init
│       │   ├── LJ016-0117_generated_e2e.wav
│       │   ├── LJ021-0078_generated_e2e.wav
│       │   ├── LJ025-0157_generated_e2e.wav
│       │   ├── LJ032-0164_generated_e2e.wav
│       │   ├── LJ033-0042_generated_e2e.wav
│       │   ├── LJ042-0219_generated_e2e.wav
│       │   ├── LJ043-0016_generated_e2e.wav
│       │   ├── LJ045-0096_generated_e2e.wav
│       │   ├── LJ046-0092_generated_e2e.wav
│       │   ├── LJ049-0022_generated_e2e.wav
│       │   └── LJ050-0118_generated_e2e.wav
│       ├── run.sh
│       ├── grad_tts_blank.json
│       └── inference_waveglow_vocoder.py
├── data
│   └── cmu_dictionary
├── text
│   ├── __pycache__
│   │   ├── cmudict.cpython-36.pyc
│   │   ├── numbers.cpython-36.pyc
│   │   ├── symbols.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── cleaners.cpython-36.pyc
│   ├── symbols.py
│   ├── LICENSE
│   ├── cmudict.py
│   ├── numbers.py
│   ├── cleaners.py
│   └── __init__.py
├── waveglow
│   ├── requirements.txt
│   ├── config.json
│   ├── denoiser.py
│   ├── convert_model.py
│   ├── inference.py
│   ├── mel2samp.py
│   ├── distributed.py
│   ├── train.py
│   ├── glow_old.py
│   └── glow.py
├── monotonic_align
│   ├── setup.py
│   ├── __init__.py
│   └── core.pyx
├── LICENSE
├── README.md
├── audio_processing.py
├── stft.py
├── commons.py
├── train.py
├── utils.py
├── modules.py
├── data_utils.py
├── attentions.py
├── filelists
│   └── ljs_audio_text_val_filelist.txt
├── models.py
└── unet.py

/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/init:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/data/cmu_dictionary:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/data/cmu_dictionary
--------------------------------------------------------------------------------
/text/__pycache__/cmudict.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/text/__pycache__/cmudict.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/numbers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/text/__pycache__/numbers.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/symbols.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/text/__pycache__/symbols.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/text/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/text/__pycache__/cleaners.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/text/__pycache__/cleaners.cpython-36.pyc
--------------------------------------------------------------------------------
/egs/gradtts_n_1000/run.sh:
--------------------------------------------------------------------------------
1 | num_gpu=$1
2 | horovodrun -np $num_gpu -H localhost:$num_gpu python3 -u ../../train.py -c grad_tts_blank.json -l ../../logdir -m 
gradtts_n_1000 -------------------------------------------------------------------------------- /waveglow/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0 2 | matplotlib==2.1.0 3 | tensorflow 4 | numpy==1.13.3 5 | inflect==0.2.5 6 | librosa==0.6.0 7 | scipy==1.0.0 8 | tensorboardX==1.1 9 | Unidecode==1.0.22 10 | pillow 11 | -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ016-0117_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ016-0117_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ021-0078_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ021-0078_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ025-0157_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ025-0157_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ032-0164_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ032-0164_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ033-0042_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ033-0042_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ042-0219_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ042-0219_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ043-0016_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ043-0016_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ045-0096_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ045-0096_generated_e2e.wav 
-------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ046-0092_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ046-0092_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ049-0022_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ049-0022_generated_e2e.wav -------------------------------------------------------------------------------- /egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ050-0118_generated_e2e.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WelkinYang/GradTTS/HEAD/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan/LJ050-0118_generated_e2e.wav -------------------------------------------------------------------------------- /monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name = 'monotonic_align', 7 | ext_modules = cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()] 9 | ) 10 | -------------------------------------------------------------------------------- /monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(value, mask): 7 | """ Cython optimised version. 8 | value: [b, t_x, t_y] 9 | mask: [b, t_x, t_y] 10 | """ 11 | value = value * mask 12 | device = value.device 13 | dtype = value.dtype 14 | value = value.data.cpu().numpy().astype(np.float32) 15 | path = np.zeros_like(value).astype(np.int32) 16 | mask = mask.data.cpu().numpy() 17 | 18 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 19 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 20 | maximum_path_c(path, value, t_x_max, t_y_max) 21 | return torch.from_numpy(path).to(device=device, dtype=dtype) 22 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? 
' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 HeyangXue1997 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /waveglow/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": { 3 | "fp16_run": true, 4 | "output_directory": "checkpoints", 5 | "epochs": 100000, 6 | "learning_rate": 1e-4, 7 | "sigma": 1.0, 8 | "iters_per_checkpoint": 2000, 9 | "batch_size": 12, 10 | "seed": 1234, 11 | "checkpoint_path": "", 12 | "with_tensorboard": false 13 | }, 14 | "data_config": { 15 | "training_files": "train_files.txt", 16 | "segment_length": 16000, 17 | "sampling_rate": 22050, 18 | "filter_length": 1024, 19 | "hop_length": 256, 20 | "win_length": 1024, 21 | "mel_fmin": 0.0, 22 | "mel_fmax": 8000.0 23 | }, 24 | "dist_config": { 25 | "dist_backend": "nccl", 26 | "dist_url": "tcp://localhost:54321" 27 | }, 28 | 29 | "waveglow_config": { 30 | "n_mel_channels": 80, 31 | "n_flows": 12, 32 | "n_group": 8, 33 | "n_early_every": 4, 34 | "n_early_size": 2, 35 | "WN_config": { 36 | "n_layers": 8, 37 | "n_channels": 256, 38 | "kernel_size": 3 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | cimport cython 4 | from cython.parallel import prange 5 | 6 | 7 | @cython.boundscheck(False) 8 | @cython.wraparound(False) 9 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 10 | cdef int x 11 | cdef int y 12 | cdef float v_prev 13 | cdef float v_cur 14 | cdef float tmp 15 | cdef int index = t_x - 1 16 | 17 | for y in range(t_y): 18 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 19 | if x == y: 20 | v_cur = max_neg_val 21 | else: 22 | v_cur = value[x, y-1] 23 | if x == 0: 24 | if y == 0: 25 | v_prev = 0. 
26 | else: 27 | v_prev = max_neg_val 28 | else: 29 | v_prev = value[x-1, y-1] 30 | value[x, y] = max(v_cur, v_prev) + value[x, y] 31 | 32 | for y in range(t_y - 1, -1, -1): 33 | path[index, y] = 1 34 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 35 | index = index - 1 36 | 37 | 38 | @cython.boundscheck(False) 39 | @cython.wraparound(False) 40 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 41 | cdef int b = values.shape[0] 42 | 43 | cdef int i 44 | for i in prange(b, nogil=True): 45 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 46 | -------------------------------------------------------------------------------- /waveglow/denoiser.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('tacotron2') 3 | import torch 4 | from layers import STFT 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """ Removes model bias from audio produced with waveglow """ 9 | 10 | def __init__(self, waveglow, filter_length=1024, n_overlap=4, 11 | win_length=1024, mode='zeros'): 12 | super(Denoiser, self).__init__() 13 | self.stft = STFT(filter_length=filter_length, 14 | hop_length=int(filter_length/n_overlap), 15 | win_length=win_length).cuda() 16 | if mode == 'zeros': 17 | mel_input = torch.zeros( 18 | (1, 80, 88), 19 | dtype=waveglow.upsample.weight.dtype, 20 | device=waveglow.upsample.weight.device) 21 | elif mode == 'normal': 22 | mel_input = torch.randn( 23 | (1, 80, 88), 24 | dtype=waveglow.upsample.weight.dtype, 25 | device=waveglow.upsample.weight.device) 26 | else: 27 | raise Exception("Mode {} if not supported".format(mode)) 28 | 29 | with torch.no_grad(): 30 | bias_audio = waveglow.infer(mel_input, sigma=0.0).float() 31 | bias_spec, _ = self.stft.transform(bias_audio) 32 | 33 | self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None]) 34 | 35 | def forward(self, audio, strength=0.1): 36 | audio_spec, audio_angles = self.stft.transform(audio.cuda().float()) 37 | audio_spec_denoised = audio_spec - self.bias_spec * strength 38 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 39 | audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) 40 | return audio_denoised 41 | -------------------------------------------------------------------------------- /egs/gradtts_n_1000/grad_tts_blank.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "use_cuda": true, 4 | "log_interval": 20, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 1e0, 8 | "betas": [0.9, 0.98], 9 | "eps": 1e-9, 10 | "warmup_steps": 4000, 11 | "scheduler": "noam", 12 | "batch_size": 2, 13 | "ddi": true, 14 | "fp16_run": false 15 | }, 16 | "data": { 17 | "load_mel_from_disk": false, 18 | "audio_path_prefix": "../../", 19 | "training_files":"../../filelists/ljs_audio_text_train_filelist.txt", 20 | "validation_files":"../../filelists/ljs_audio_text_val_filelist.txt", 21 | "text_cleaners":["english_cleaners"], 22 | "max_wav_value": 32768.0, 23 | "sampling_rate": 22050, 24 | "filter_length": 1024, 25 | "hop_length": 256, 26 | "win_length": 1024, 27 | "n_mel_channels": 80, 28 | "mel_fmin": 0.0, 29 | "mel_fmax": 8000.0, 30 | "add_noise": true, 31 | "add_blank": true, 32 | "cmudict_path": "../../data/cmu_dictionary" 33 | }, 34 | "model": { 35 | "hidden_channels": 192, 36 | "filter_channels": 768, 37 | "filter_channels_dp": 256, 38 | "kernel_size": 3, 
39 | "p_dropout": 0.1, 40 | "n_blocks_dec": 12, 41 | "n_layers_enc": 6, 42 | "n_heads": 2, 43 | "p_dropout_dec": 0.05, 44 | "dilation_rate": 1, 45 | "kernel_size_dec": 5, 46 | "n_block_layers": 4, 47 | "n_sqz": 2, 48 | "prenet": true, 49 | "mean_only": true, 50 | "hidden_channels_enc": 192, 51 | "hidden_channels_dec": 192, 52 | "window_size": 4, 53 | "dec_dim_mults": [1, 2, 4], 54 | "dec_groups": 8, 55 | "dec_unet_channels": 64, 56 | "dec_unet_in_channels": 2, 57 | "dec_with_time_emb": true, 58 | "beta_0" :0.05, 59 | "beta_1": 20, 60 | "N": 1000, 61 | "T": 1 62 | } 63 | } 64 | 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GradTTS 2 | ## Unofficial Pytorch implementation of "Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech" ([arxiv](https://arxiv.org/abs/2105.06337)) 3 | 4 | ## About this repo 5 | This is an unofficial implementation of GradTTS. We created this project based on GlowTTS (https://github.com/jaywalnut310/glow-tts). We replace the GlowDecoder with DiffusionDecoder which follows the settings of the original paper. In addition, we also replace torch.distributed with horovod for convenience and we don't use fp16 now. 6 | 7 | ## Updates 8 | 9 | 2021/07/28: [LJSpeech Samples](https://github.com/WelkinYang/GradTTS/tree/main/egs/gradtts_n_1000/n_1000_steps_259000_ljspeech_hifigan) uploaded which has the same performance as the original paper's demo. 10 | 11 | ## Training and inference 12 | Please go to egs/ folder, and see run.sh and inference_waveglow_vocoder.py for example use. Before training, please download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), then rename or create a link to the dataset folder: `ln -s /path/to/LJSpeech-1.1/wavs DUMMY`. And build Monotonic Alignment Search Code (Cython): `cd monotonic_align; python setup.py build_ext --inplace`. Before inference, you should download waveglow checkpoint from [download_link](https://drive.google.com/file/d/1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF/view) and put it into the waveglow folder. 
13 | 14 | ## Reference Materials 15 | [Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech](https://arxiv.org/abs/2105.06337) 16 | 17 | [GlowTTS](https://github.com/jaywalnut310/glow-tts) 18 | 19 | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) 20 | 21 | [score_sde_pytorch](https://github.com/yang-song/score_sde_pytorch) 22 | 23 | [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch) 24 | 25 | ## Authors 26 | Heyang Xue(https://github.com/WelkinYang) and Qicong Xie(https://github.com/QicongXie) 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = 
re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | -------------------------------------------------------------------------------- /audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 
32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /waveglow/convert_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | 5 | def _check_model_old_version(model): 6 | if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'): 7 | return True 8 | else: 9 | return False 10 | 11 | 12 | def _update_model_res_skip(old_model, new_model): 13 | for idx in range(0, len(new_model.WN)): 14 | wavenet = new_model.WN[idx] 15 | n_channels = wavenet.n_channels 16 | n_layers = wavenet.n_layers 17 | wavenet.res_skip_layers = torch.nn.ModuleList() 18 | for i in range(0, n_layers): 19 | if i < n_layers - 1: 20 | res_skip_channels = 2*n_channels 21 | else: 22 | res_skip_channels = n_channels 23 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 24 | skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i]) 25 | if i < n_layers - 1: 26 | res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i]) 27 | res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight])) 28 | res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias])) 29 | else: 30 | res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight) 31 | res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias) 32 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 33 | wavenet.res_skip_layers.append(res_skip_layer) 34 | del wavenet.res_layers 35 | del wavenet.skip_layers 36 | 37 | def _update_model_cond(old_model, new_model): 38 | for idx in range(0, len(new_model.WN)): 39 | wavenet = new_model.WN[idx] 40 | n_channels = wavenet.n_channels 41 | n_layers = 
wavenet.n_layers 42 | n_mel_channels = wavenet.cond_layers[0].weight.shape[1] 43 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1) 44 | cond_layer_weight = [] 45 | cond_layer_bias = [] 46 | for i in range(0, n_layers): 47 | _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i]) 48 | cond_layer_weight.append(_cond_layer.weight) 49 | cond_layer_bias.append(_cond_layer.bias) 50 | cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight)) 51 | cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias)) 52 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 53 | wavenet.cond_layer = cond_layer 54 | del wavenet.cond_layers 55 | 56 | def update_model(old_model): 57 | if not _check_model_old_version(old_model): 58 | return old_model 59 | new_model = copy.deepcopy(old_model) 60 | if hasattr(old_model.WN[0], 'res_layers'): 61 | _update_model_res_skip(old_model, new_model) 62 | if hasattr(old_model.WN[0], 'cond_layers'): 63 | _update_model_cond(old_model, new_model) 64 | return new_model 65 | 66 | if __name__ == '__main__': 67 | old_model_path = sys.argv[1] 68 | new_model_path = sys.argv[2] 69 | model = torch.load(old_model_path, map_location='cpu') 70 | model['model'] = update_model(model['model']) 71 | torch.save(model, new_model_path) 72 | 73 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def get_arpabet(word, dictionary): 16 | word_arpabet = dictionary.lookup(word) 17 | if word_arpabet is not None: 18 | return "{" + word_arpabet[0] + "}" 19 | else: 20 | return word 21 | 22 | 23 | def text_to_sequence(text, cleaner_names, dictionary=None): 24 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 25 | 26 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 27 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
28 | 29 | Args: 30 | text: string to convert to a sequence 31 | cleaner_names: names of the cleaner functions to run the text through 32 | dictionary: arpabet class with arpabet dictionary 33 | 34 | Returns: 35 | List of integers corresponding to the symbols in the text 36 | ''' 37 | sequence = [] 38 | 39 | space = _symbols_to_sequence(' ') 40 | # Check for curly braces and treat their contents as ARPAbet: 41 | while len(text): 42 | m = _curly_re.match(text) 43 | if not m: 44 | clean_text = _clean_text(text, cleaner_names) 45 | if dictionary is not None: 46 | clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")] 47 | for i in range(len(clean_text)): 48 | t = clean_text[i] 49 | if t.startswith("{"): 50 | sequence += _arpabet_to_sequence(t[1:-1]) 51 | else: 52 | sequence += _symbols_to_sequence(t) 53 | sequence += space 54 | else: 55 | sequence += _symbols_to_sequence(clean_text) 56 | break 57 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 58 | sequence += _arpabet_to_sequence(m.group(2)) 59 | text = m.group(3) 60 | 61 | # remove trailing space 62 | if dictionary is not None: 63 | sequence = sequence[:-1] if sequence[-1] == space[0] else sequence 64 | return sequence 65 | 66 | 67 | def sequence_to_text(sequence): 68 | '''Converts a sequence of IDs back to a string''' 69 | result = '' 70 | for symbol_id in sequence: 71 | if symbol_id in _id_to_symbol: 72 | s = _id_to_symbol[symbol_id] 73 | # Enclose ARPAbet back in curly braces: 74 | if len(s) > 1 and s[0] == '@': 75 | s = '{%s}' % s[1:] 76 | result += s 77 | return result.replace('}{', ' ') 78 | 79 | 80 | def _clean_text(text, cleaner_names): 81 | for name in cleaner_names: 82 | cleaner = getattr(cleaners, name) 83 | if not cleaner: 84 | raise Exception('Unknown cleaner: %s' % name) 85 | text = cleaner(text) 86 | return text 87 | 88 | 89 | def _symbols_to_sequence(symbols): 90 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 91 | 92 | 93 | def _arpabet_to_sequence(text): 94 | return _symbols_to_sequence(['@' + s for s in text.split()]) 95 | 96 | 97 | def _should_keep_symbol(s): 98 | return s in _symbol_to_id and s is not '_' and s is not '~' 99 | -------------------------------------------------------------------------------- /egs/gradtts_n_1000/inference_waveglow_vocoder.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import sys 4 | sys.path.append('../../waveglow/') 5 | sys.path.append('../../') 6 | import librosa 7 | import numpy as np 8 | import os 9 | import glob 10 | import json 11 | 12 | import torch 13 | from text import text_to_sequence, cmudict 14 | from text.symbols import symbols 15 | import commons 16 | import attentions 17 | import modules 18 | import models 19 | import utils 20 | import soundfile as sf 21 | 22 | def save_wav(wav, path, sample_rate, norm=False): 23 | if norm: 24 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 25 | wavfile.write(path, sample_rate, wav.astype(np.int16)) 26 | else: 27 | sf.write(path, wav, sample_rate) 28 | 29 | # load WaveGlow 30 | waveglow_path = '../../waveglow/waveglow_256channels_universal_v5.pt' # or change to the latest version of the pretrained WaveGlow. 
31 | waveglow = torch.load(waveglow_path)['model'] 32 | for k, m in waveglow.named_modules(): 33 | m._non_persistent_buffers_set = set() 34 | waveglow = waveglow.remove_weightnorm(waveglow) 35 | _ = waveglow.cuda().eval() 36 | 37 | # If you are using your own trained model 38 | model_dir = sys.argv[1] 39 | test_files_path = sys.argv[2] 40 | 41 | hps = utils.get_hparams_from_dir(model_dir) 42 | checkpoint_path = utils.latest_checkpoint_path(model_dir) 43 | 44 | # If you are using a provided pretrained model 45 | # hps = utils.get_hparams_from_file("./configs/any_config_file.json") 46 | # checkpoint_path = "/path/to/pretrained_model" 47 | 48 | model = models.DiffusionGenerator( 49 | len(symbols) + getattr(hps.data, "add_blank", False), 50 | enc_out_channels=hps.data.n_mel_channels, 51 | **hps.model).to("cuda") 52 | 53 | utils.load_checkpoint(checkpoint_path, model) 54 | _ = model.eval() 55 | 56 | cmu_dict = cmudict.CMUDict(hps.data.cmudict_path) 57 | 58 | # normalizing & type casting 59 | def normalize_audio(x, max_wav_value=hps.data.max_wav_value): 60 | return np.clip((x / np.abs(x).max()) * max_wav_value, -32768, 32767).astype("int16") 61 | 62 | print(test_files_path) 63 | test_lines = open(test_files_path, 'r', encoding='utf-8').readlines() 64 | for line in test_lines: 65 | file_name = os.path.basename(line.strip().split('|')[0]) 66 | print(file_name) 67 | 68 | tst_stn = line.strip().split('|')[-1] 69 | 70 | if getattr(hps.data, "add_blank", False): 71 | text_norm = text_to_sequence(tst_stn.strip(), ['english_cleaners'], cmu_dict) 72 | text_norm = commons.intersperse(text_norm, len(symbols)) 73 | else: # If not using "add_blank" option during training, adding spaces at the beginning and the end of utterance improves quality 74 | tst_stn = " " + tst_stn.strip() + " " 75 | text_norm = text_to_sequence(tst_stn.strip(), ['english_cleaners'], cmu_dict) 76 | sequence = np.array(text_norm)[None, :] 77 | print("".join([symbols[c] if c < len(symbols) else "" for c in sequence[0]])) 78 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() 79 | x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda() 80 | 81 | 82 | with torch.no_grad(): 83 | length_scale = 1.0 84 | (y_gen_tst, *_), *_, (attn_gen, *_) = model(x_tst, x_tst_lengths, gen=True, length_scale=length_scale) 85 | try: 86 | audio = waveglow.infer(y_gen_tst.half(), sigma=.666) 87 | except: 88 | audio = waveglow.infer(y_gen_tst, sigma=.666) 89 | 90 | save_wav(normalize_audio(audio[0].clamp(-1,1).data.cpu().float().numpy()), os.path.join('test_outputs', file_name), sample_rate=hps.data.sampling_rate) 91 | -------------------------------------------------------------------------------- /waveglow/inference.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 
11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | # ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import os 28 | from scipy.io.wavfile import write 29 | import torch 30 | from mel2samp import files_to_list, MAX_WAV_VALUE 31 | from denoiser import Denoiser 32 | 33 | 34 | def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16, 35 | denoiser_strength): 36 | mel_files = files_to_list(mel_files) 37 | waveglow = torch.load(waveglow_path)['model'] 38 | waveglow = waveglow.remove_weightnorm(waveglow) 39 | waveglow.cuda().eval() 40 | if is_fp16: 41 | from apex import amp 42 | waveglow, _ = amp.initialize(waveglow, [], opt_level="O3") 43 | 44 | if denoiser_strength > 0: 45 | denoiser = Denoiser(waveglow).cuda() 46 | 47 | for i, file_path in enumerate(mel_files): 48 | file_name = os.path.splitext(os.path.basename(file_path))[0] 49 | mel = torch.load(file_path) 50 | mel = torch.autograd.Variable(mel.cuda()) 51 | mel = torch.unsqueeze(mel, 0) 52 | mel = mel.half() if is_fp16 else mel 53 | with torch.no_grad(): 54 | audio = waveglow.infer(mel, sigma=sigma) 55 | if denoiser_strength > 0: 56 | audio = denoiser(audio, denoiser_strength) 57 | audio = audio * MAX_WAV_VALUE 58 | audio = audio.squeeze() 59 | audio = audio.cpu().numpy() 60 | audio = audio.astype('int16') 61 | audio_path = os.path.join( 62 | output_dir, "{}_synthesis.wav".format(file_name)) 63 | write(audio_path, sampling_rate, audio) 64 | print(audio_path) 65 | 66 | 67 | if __name__ == "__main__": 68 | import argparse 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('-f', "--filelist_path", required=True) 72 | parser.add_argument('-w', '--waveglow_path', 73 | help='Path to waveglow decoder checkpoint with model') 74 | parser.add_argument('-o', "--output_dir", required=True) 75 | parser.add_argument("-s", "--sigma", default=1.0, type=float) 76 | parser.add_argument("--sampling_rate", default=22050, type=int) 77 | parser.add_argument("--is_fp16", action="store_true") 78 | parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float, 79 | help='Removes model bias. 
Start with 0.1 and adjust') 80 | 81 | args = parser.parse_args() 82 | 83 | main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir, 84 | args.sampling_rate, args.is_fp16, args.denoiser_strength) 85 | -------------------------------------------------------------------------------- /waveglow/mel2samp.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # *****************************************************************************\ 27 | import os 28 | import random 29 | import argparse 30 | import json 31 | import torch 32 | import torch.utils.data 33 | import sys 34 | from scipy.io.wavfile import read 35 | 36 | # We're using the audio processing from TacoTron2 to make sure it matches 37 | sys.path.insert(0, 'tacotron2') 38 | from tacotron2.layers import TacotronSTFT 39 | 40 | MAX_WAV_VALUE = 32768.0 41 | 42 | def files_to_list(filename): 43 | """ 44 | Takes a text file of filenames and makes a list of filenames 45 | """ 46 | with open(filename, encoding='utf-8') as f: 47 | files = f.readlines() 48 | 49 | files = [f.rstrip() for f in files] 50 | return files 51 | 52 | def load_wav_to_torch(full_path): 53 | """ 54 | Loads wavdata into torch array 55 | """ 56 | sampling_rate, data = read(full_path) 57 | return torch.from_numpy(data).float(), sampling_rate 58 | 59 | 60 | class Mel2Samp(torch.utils.data.Dataset): 61 | """ 62 | This is the main class that calculates the spectrogram and returns the 63 | spectrogram, audio pair. 
64 | """ 65 | def __init__(self, training_files, segment_length, filter_length, 66 | hop_length, win_length, sampling_rate, mel_fmin, mel_fmax): 67 | self.audio_files = files_to_list(training_files) 68 | random.seed(1234) 69 | random.shuffle(self.audio_files) 70 | self.stft = TacotronSTFT(filter_length=filter_length, 71 | hop_length=hop_length, 72 | win_length=win_length, 73 | sampling_rate=sampling_rate, 74 | mel_fmin=mel_fmin, mel_fmax=mel_fmax) 75 | self.segment_length = segment_length 76 | self.sampling_rate = sampling_rate 77 | 78 | def get_mel(self, audio): 79 | audio_norm = audio / MAX_WAV_VALUE 80 | audio_norm = audio_norm.unsqueeze(0) 81 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 82 | melspec = self.stft.mel_spectrogram(audio_norm) 83 | melspec = torch.squeeze(melspec, 0) 84 | return melspec 85 | 86 | def __getitem__(self, index): 87 | # Read audio 88 | filename = self.audio_files[index] 89 | audio, sampling_rate = load_wav_to_torch(filename) 90 | if sampling_rate != self.sampling_rate: 91 | raise ValueError("{} SR doesn't match target {} SR".format( 92 | sampling_rate, self.sampling_rate)) 93 | 94 | # Take segment 95 | if audio.size(0) >= self.segment_length: 96 | max_audio_start = audio.size(0) - self.segment_length 97 | audio_start = random.randint(0, max_audio_start) 98 | audio = audio[audio_start:audio_start+self.segment_length] 99 | else: 100 | audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data 101 | 102 | mel = self.get_mel(audio) 103 | audio = audio / MAX_WAV_VALUE 104 | 105 | return (mel, audio) 106 | 107 | def __len__(self): 108 | return len(self.audio_files) 109 | 110 | # =================================================================== 111 | # Takes directory of clean audio and makes directory of spectrograms 112 | # Useful for making test sets 113 | # =================================================================== 114 | if __name__ == "__main__": 115 | # Get defaults so it can work with no Sacred 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('-f', "--filelist_path", required=True) 118 | parser.add_argument('-c', '--config', type=str, 119 | help='JSON file for configuration') 120 | parser.add_argument('-o', '--output_dir', type=str, 121 | help='Output directory') 122 | args = parser.parse_args() 123 | 124 | with open(args.config) as f: 125 | data = f.read() 126 | data_config = json.loads(data)["data_config"] 127 | mel2samp = Mel2Samp(**data_config) 128 | 129 | filepaths = files_to_list(args.filelist_path) 130 | 131 | # Make directory if it doesn't exist 132 | if not os.path.isdir(args.output_dir): 133 | os.makedirs(args.output_dir) 134 | os.chmod(args.output_dir, 0o775) 135 | 136 | for filepath in filepaths: 137 | audio, sr = load_wav_to_torch(filepath) 138 | melspectrogram = mel2samp.get_mel(audio) 139 | filename = os.path.basename(filepath) 140 | new_filepath = args.output_dir + '/' + filename + '.pt' 141 | print(new_filepath) 142 | torch.save(melspectrogram, new_filepath) 143 | -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 
6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from librosa import stft, istft 40 | from audio_processing import window_sumsquare 41 | 42 | 43 | class STFT(torch.nn.Module): 44 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 45 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 46 | window='hann'): 47 | super(STFT, self).__init__() 48 | self.filter_length = filter_length 49 | self.hop_length = hop_length 50 | self.win_length = win_length 51 | self.window = window 52 | self.forward_transform = None 53 | scale = self.filter_length / self.hop_length 54 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 55 | 56 | cutoff = int((self.filter_length / 2 + 1)) 57 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 58 | np.imag(fourier_basis[:cutoff, :])]) 59 | 60 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 61 | inverse_basis = torch.FloatTensor( 62 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 63 | 64 | if window is not None: 65 | assert(filter_length >= win_length) 66 | # get window and zero center pad it to filter_length 67 | fft_window = get_window(window, win_length, fftbins=True) 68 | fft_window = pad_center(fft_window, filter_length) 69 | fft_window = torch.from_numpy(fft_window).float() 70 | 71 | # window the bases 72 | forward_basis *= fft_window 73 | inverse_basis *= fft_window 74 | 75 | self.register_buffer('forward_basis', forward_basis.float()) 76 | self.register_buffer('inverse_basis', inverse_basis.float()) 77 | 78 | def transform(self, input_data): 79 | num_batches = input_data.size(0) 80 | num_samples = input_data.size(1) 81 | 82 | self.num_samples = num_samples 83 | 84 | if input_data.device.type == "cuda": 85 | # similar to librosa, reflect-pad the input 86 | input_data = 
input_data.view(num_batches, 1, num_samples) 87 | input_data = F.pad( 88 | input_data.unsqueeze(1), 89 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 90 | mode='reflect') 91 | input_data = input_data.squeeze(1) 92 | 93 | forward_transform = F.conv1d( 94 | input_data, 95 | self.forward_basis, 96 | stride=self.hop_length, 97 | padding=0) 98 | 99 | cutoff = int((self.filter_length / 2) + 1) 100 | real_part = forward_transform[:, :cutoff, :] 101 | imag_part = forward_transform[:, cutoff:, :] 102 | else: 103 | x = input_data.detach().numpy() 104 | real_part = [] 105 | imag_part = [] 106 | for y in x: 107 | y_ = stft(y, self.filter_length, self.hop_length, self.win_length, self.window) 108 | real_part.append(y_.real[None,:,:]) 109 | imag_part.append(y_.imag[None,:,:]) 110 | real_part = np.concatenate(real_part, 0) 111 | imag_part = np.concatenate(imag_part, 0) 112 | 113 | real_part = torch.from_numpy(real_part).to(input_data.dtype) 114 | imag_part = torch.from_numpy(imag_part).to(input_data.dtype) 115 | 116 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 117 | phase = torch.atan2(imag_part.data, real_part.data) 118 | 119 | return magnitude, phase 120 | 121 | def inverse(self, magnitude, phase): 122 | recombine_magnitude_phase = torch.cat( 123 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 124 | 125 | if magnitude.device.type == "cuda": 126 | inverse_transform = F.conv_transpose1d( 127 | recombine_magnitude_phase, 128 | self.inverse_basis, 129 | stride=self.hop_length, 130 | padding=0) 131 | 132 | if self.window is not None: 133 | window_sum = window_sumsquare( 134 | self.window, magnitude.size(-1), hop_length=self.hop_length, 135 | win_length=self.win_length, n_fft=self.filter_length, 136 | dtype=np.float32) 137 | # remove modulation effects 138 | approx_nonzero_indices = torch.from_numpy( 139 | np.where(window_sum > tiny(window_sum))[0]) 140 | window_sum = torch.from_numpy(window_sum).to(inverse_transform.device) 141 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 142 | 143 | # scale by hop ratio 144 | inverse_transform *= float(self.filter_length) / self.hop_length 145 | 146 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 147 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 148 | inverse_transform = inverse_transform.squeeze(1) 149 | else: 150 | x_org = recombine_magnitude_phase.detach().numpy() 151 | n_b, n_f, n_t = x_org.shape 152 | x = np.empty([n_b, n_f//2, n_t], dtype=np.complex64) 153 | x.real = x_org[:,:n_f//2] 154 | x.imag = x_org[:,n_f//2:] 155 | inverse_transform = [] 156 | for y in x: 157 | y_ = istft(y, self.hop_length, self.win_length, self.window) 158 | inverse_transform.append(y_[None,:]) 159 | inverse_transform = np.concatenate(inverse_transform, 0) 160 | inverse_transform = torch.from_numpy(inverse_transform).to(recombine_magnitude_phase.dtype) 161 | 162 | return inverse_transform 163 | 164 | def forward(self, input_data): 165 | self.magnitude, self.phase = self.transform(input_data) 166 | reconstruction = self.inverse(self.magnitude, self.phase) 167 | return reconstruction 168 | -------------------------------------------------------------------------------- /waveglow/distributed.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import os 28 | import sys 29 | import time 30 | import subprocess 31 | import argparse 32 | 33 | import torch 34 | import torch.distributed as dist 35 | from torch.autograd import Variable 36 | 37 | def reduce_tensor(tensor, num_gpus): 38 | rt = tensor.clone() 39 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 40 | rt /= num_gpus 41 | return rt 42 | 43 | def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): 44 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 45 | print("Initializing Distributed") 46 | 47 | # Set cuda device so everything is done on the right GPU. 48 | torch.cuda.set_device(rank % torch.cuda.device_count()) 49 | 50 | # Initialize distributed communication 51 | dist.init_process_group(dist_backend, init_method=dist_url, 52 | world_size=num_gpus, rank=rank, 53 | group_name=group_name) 54 | 55 | def _flatten_dense_tensors(tensors): 56 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 57 | same dense type. 58 | Since inputs are dense, the resulting tensor will be a concatenated 1D 59 | buffer. Element-wise operation on this buffer will be equivalent to 60 | operating individually. 61 | Arguments: 62 | tensors (Iterable[Tensor]): dense tensors to flatten. 63 | Returns: 64 | A contiguous 1D buffer containing input tensors. 65 | """ 66 | if len(tensors) == 1: 67 | return tensors[0].contiguous().view(-1) 68 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 69 | return flat 70 | 71 | def _unflatten_dense_tensors(flat, tensors): 72 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 73 | same dense type, and that flat is given by _flatten_dense_tensors. 74 | Arguments: 75 | flat (Tensor): flattened dense tensors to unflatten. 76 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 77 | unflatten flat. 
78 | Returns: 79 | Unflattened dense tensors with sizes same as tensors and values from 80 | flat. 81 | """ 82 | outputs = [] 83 | offset = 0 84 | for tensor in tensors: 85 | numel = tensor.numel() 86 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 87 | offset += numel 88 | return tuple(outputs) 89 | 90 | def apply_gradient_allreduce(module): 91 | """ 92 | Modifies existing model to do gradient allreduce, but doesn't change class 93 | so you don't need "module" 94 | """ 95 | if not hasattr(dist, '_backend'): 96 | module.warn_on_half = True 97 | else: 98 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 99 | 100 | for p in module.state_dict().values(): 101 | if not torch.is_tensor(p): 102 | continue 103 | dist.broadcast(p, 0) 104 | 105 | def allreduce_params(): 106 | if(module.needs_reduction): 107 | module.needs_reduction = False 108 | buckets = {} 109 | for param in module.parameters(): 110 | if param.requires_grad and param.grad is not None: 111 | tp = type(param.data) 112 | if tp not in buckets: 113 | buckets[tp] = [] 114 | buckets[tp].append(param) 115 | if module.warn_on_half: 116 | if torch.cuda.HalfTensor in buckets: 117 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 118 | " It is recommended to use the NCCL backend in this case. This currently requires" + 119 | "PyTorch built from top of tree master.") 120 | module.warn_on_half = False 121 | 122 | for tp in buckets: 123 | bucket = buckets[tp] 124 | grads = [param.grad.data for param in bucket] 125 | coalesced = _flatten_dense_tensors(grads) 126 | dist.all_reduce(coalesced) 127 | coalesced /= dist.get_world_size() 128 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 129 | buf.copy_(synced) 130 | 131 | for param in list(module.parameters()): 132 | def allreduce_hook(*unused): 133 | Variable._execution_engine.queue_callback(allreduce_params) 134 | if param.requires_grad: 135 | param.register_hook(allreduce_hook) 136 | dir(param) 137 | 138 | def set_needs_reduction(self, input, output): 139 | self.needs_reduction = True 140 | 141 | module.register_forward_hook(set_needs_reduction) 142 | return module 143 | 144 | 145 | def main(config, stdout_dir, args_str): 146 | args_list = ['train.py'] 147 | args_list += args_str.split(' ') if len(args_str) > 0 else [] 148 | 149 | args_list.append('--config={}'.format(config)) 150 | 151 | num_gpus = torch.cuda.device_count() 152 | args_list.append('--num_gpus={}'.format(num_gpus)) 153 | args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S"))) 154 | 155 | if not os.path.isdir(stdout_dir): 156 | os.makedirs(stdout_dir) 157 | os.chmod(stdout_dir, 0o775) 158 | 159 | workers = [] 160 | 161 | for i in range(num_gpus): 162 | args_list[-2] = '--rank={}'.format(i) 163 | stdout = None if i == 0 else open( 164 | os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w") 165 | print(args_list) 166 | p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout) 167 | workers.append(p) 168 | 169 | for p in workers: 170 | p.wait() 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('-c', '--config', type=str, required=True, 176 | help='JSON file for configuration') 177 | parser.add_argument('-s', '--stdout_dir', type=str, default=".", 178 | help='directory to save stoud logs') 179 | parser.add_argument( 180 | '-a', '--args_str', type=str, default='', 181 | help='double quoted string with space separated key value 
pairs') 182 | 183 | args = parser.parse_args() 184 | main(args.config, args.stdout_dir, args.args_str) 185 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from librosa.filters import mel as librosa_mel_fn 8 | from audio_processing import dynamic_range_compression 9 | from audio_processing import dynamic_range_decompression 10 | from stft import STFT 11 | 12 | 13 | def intersperse(lst, item): 14 | result = [item] * (len(lst) * 2 + 1) 15 | result[1::2] = lst 16 | return result 17 | 18 | 19 | def mle_loss(z, m, logs, mask): 20 | l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term 21 | l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes 22 | l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term 23 | return l 24 | 25 | 26 | def duration_loss(logw, logw_, lengths): 27 | l = torch.sum((logw - logw_)**2) / torch.sum(lengths) 28 | return l 29 | 30 | 31 | @torch.jit.script 32 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 33 | n_channels_int = n_channels[0] 34 | in_act = input_a + input_b 35 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 36 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 37 | acts = t_act * s_act 38 | return acts 39 | 40 | 41 | def convert_pad_shape(pad_shape): 42 | l = pad_shape[::-1] 43 | pad_shape = [item for sublist in l for item in sublist] 44 | return pad_shape 45 | 46 | 47 | def shift_1d(x): 48 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 49 | return x 50 | 51 | 52 | def sequence_mask(length, max_length=None): 53 | if max_length is None: 54 | max_length = length.max() 55 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 56 | return x.unsqueeze(0) < length.unsqueeze(1) 57 | 58 | 59 | def maximum_path(value, mask, max_neg_val=-np.inf): 60 | """ Numpy-friendly version. It's about 4 times faster than torch version. 
61 | value: [b, t_x, t_y] 62 | mask: [b, t_x, t_y] 63 | """ 64 | value = value * mask 65 | 66 | device = value.device 67 | dtype = value.dtype 68 | value = value.cpu().detach().numpy() 69 | mask = mask.cpu().detach().numpy().astype(np.bool) 70 | 71 | b, t_x, t_y = value.shape 72 | direction = np.zeros(value.shape, dtype=np.int64) 73 | v = np.zeros((b, t_x), dtype=np.float32) 74 | x_range = np.arange(t_x, dtype=np.float32).reshape(1,-1) 75 | for j in range(t_y): 76 | v0 = np.pad(v, [[0,0],[1,0]], mode="constant", constant_values=max_neg_val)[:, :-1] 77 | v1 = v 78 | max_mask = (v1 >= v0) 79 | v_max = np.where(max_mask, v1, v0) 80 | direction[:, :, j] = max_mask 81 | 82 | index_mask = (x_range <= j) 83 | v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) 84 | direction = np.where(mask, direction, 1) 85 | 86 | path = np.zeros(value.shape, dtype=np.float32) 87 | index = mask[:, :, 0].sum(1).astype(np.int64) - 1 88 | index_range = np.arange(b) 89 | for j in reversed(range(t_y)): 90 | path[index_range, index, j] = 1 91 | index = index + direction[index_range, index, j] - 1 92 | path = path * mask.astype(np.float32) 93 | path = torch.from_numpy(path).to(device=device, dtype=dtype) 94 | return path 95 | 96 | 97 | def generate_path(duration, mask): 98 | """ 99 | duration: [b, t_x] 100 | mask: [b, t_x, t_y] 101 | """ 102 | device = duration.device 103 | 104 | b, t_x, t_y = mask.shape 105 | cum_duration = torch.cumsum(duration, 1) 106 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 107 | 108 | cum_duration_flat = cum_duration.view(b * t_x) 109 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 110 | path = path.view(b, t_x, t_y) 111 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:,:-1] 112 | path = path * mask 113 | return path 114 | 115 | 116 | class Adam(): 117 | def __init__(self, scheduler, dim_model, lr, warmup_steps=4000): 118 | self.scheduler = scheduler 119 | self.dim_model = dim_model 120 | self.warmup_steps = warmup_steps 121 | self.lr = lr 122 | 123 | self.step_num = 1 124 | self.cur_lr = lr * self._get_lr_scale() 125 | 126 | def _get_lr_scale(self): 127 | if self.scheduler == "noam": 128 | return np.power(self.dim_model, -0.5) * np.min([np.power(self.step_num, -0.5), self.step_num * np.power(self.warmup_steps, -1.5)]) 129 | else: 130 | return 1 131 | 132 | def _update_learning_rate(self): 133 | self.step_num += 1 134 | if self.scheduler == "noam": 135 | self.cur_lr = self.lr * self._get_lr_scale() 136 | for param_group in self._optim.param_groups: 137 | param_group['lr'] = self.cur_lr 138 | 139 | def set_optimizer(self, optimizer): 140 | self._optim = optimizer 141 | 142 | def get_lr(self): 143 | return self.cur_lr 144 | 145 | def step(self): 146 | self._optim.step() 147 | self._update_learning_rate() 148 | 149 | def zero_grad(self): 150 | self._optim.zero_grad() 151 | 152 | def load_state_dict(self, d): 153 | self._optim.load_state_dict(d) 154 | 155 | def state_dict(self): 156 | return self._optim.state_dict() 157 | 158 | def param_groups(self): 159 | return self._optim.param_groups() 160 | 161 | 162 | 163 | class TacotronSTFT(nn.Module): 164 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 165 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 166 | mel_fmax=8000.0): 167 | super(TacotronSTFT, self).__init__() 168 | self.n_mel_channels = n_mel_channels 169 | self.sampling_rate = sampling_rate 170 | self.stft_fn = STFT(filter_length, hop_length, win_length) 171 | mel_basis = librosa_mel_fn( 
172 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 173 | mel_basis = torch.from_numpy(mel_basis).float() 174 | self.register_buffer('mel_basis', mel_basis) 175 | 176 | def spectral_normalize(self, magnitudes): 177 | output = dynamic_range_compression(magnitudes) 178 | return output 179 | 180 | def spectral_de_normalize(self, magnitudes): 181 | output = dynamic_range_decompression(magnitudes) 182 | return output 183 | 184 | def mel_spectrogram(self, y): 185 | """Computes mel-spectrograms from a batch of waves 186 | PARAMS 187 | ------ 188 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 189 | 190 | RETURNS 191 | ------- 192 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 193 | """ 194 | assert(torch.min(y.data) >= -1) 195 | assert(torch.max(y.data) <= 1) 196 | 197 | magnitudes, phases = self.stft_fn.transform(y) 198 | magnitudes = magnitudes.data 199 | mel_output = torch.matmul(self.mel_basis, magnitudes) 200 | mel_output = self.spectral_normalize(mel_output) 201 | return mel_output 202 | 203 | 204 | def clip_grad_value_(parameters, clip_value, norm_type=2): 205 | if isinstance(parameters, torch.Tensor): 206 | parameters = [parameters] 207 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 208 | norm_type = float(norm_type) 209 | clip_value = float(clip_value) 210 | 211 | total_norm = 0 212 | for p in parameters: 213 | param_norm = p.grad.data.norm(norm_type) 214 | total_norm += param_norm.item() ** norm_type 215 | 216 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 217 | total_norm = total_norm ** (1. / norm_type) 218 | return total_norm 219 | 220 | 221 | def squeeze(x, x_mask=None, n_sqz=2): 222 | b, c, t = x.size() 223 | 224 | t = (t // n_sqz) * n_sqz 225 | x = x[:,:,:t] 226 | x_sqz = x.view(b, c, t//n_sqz, n_sqz) 227 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c*n_sqz, t//n_sqz) 228 | 229 | if x_mask is not None: 230 | x_mask = x_mask[:,:,n_sqz-1::n_sqz] 231 | else: 232 | x_mask = torch.ones(b, 1, t//n_sqz).to(device=x.device, dtype=x.dtype) 233 | return x_sqz * x_mask, x_mask 234 | 235 | 236 | def unsqueeze(x, x_mask=None, n_sqz=2): 237 | b, c, t = x.size() 238 | 239 | x_unsqz = x.view(b, n_sqz, c//n_sqz, t) 240 | x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c//n_sqz, t*n_sqz) 241 | 242 | if x_mask is not None: 243 | x_mask = x_mask.unsqueeze(-1).repeat(1,1,1,n_sqz).view(b, 1, t*n_sqz) 244 | else: 245 | x_mask = torch.ones(b, 1, t*n_sqz).to(device=x.device, dtype=x.dtype) 246 | return x_unsqz * x_mask, x_mask 247 | 248 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import math 5 | import torch 6 | from torch import nn, optim 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.multiprocessing as mp 11 | import torch.distributed as dist 12 | 13 | from data_utils import TextMelLoader, TextMelCollate 14 | import models 15 | import commons 16 | import utils 17 | from text.symbols import symbols 18 | import horovod.torch as hvd 19 | 20 | hvd.init() 21 | torch.cuda.set_device(hvd.local_rank()) 22 | 23 | global_step = 0 24 | 25 | 26 | def main(): 27 | """Assume Single Node Multi GPUs Training Only""" 28 | assert torch.cuda.is_available(), "CPU training is not allowed." 
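    # NOTE: hvd.init() and torch.cuda.set_device(hvd.local_rank()) run at import
    # time above, so each horovodrun worker is already pinned to its local GPU
    # before main() executes. Also, the fp16_run branch further down calls
    # amp.initialize / amp.scale_loss, but no apex import appears in this file;
    # enabling fp16_run would require `from apex import amp` to be importable.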
29 | 30 | n_gpus = torch.cuda.device_count() 31 | hps = utils.get_hparams() 32 | train_and_eval(n_gpus, hps) 33 | 34 | 35 | def train_and_eval(n_gpus, hps): 36 | global global_step 37 | if hvd.local_rank() == 0: 38 | logger = utils.get_logger(hps.model_dir) 39 | logger.info(hps) 40 | utils.check_git_hash(hps.model_dir) 41 | writer = SummaryWriter(log_dir=hps.model_dir) 42 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 43 | torch.manual_seed(hps.train.seed) 44 | 45 | train_dataset = TextMelLoader(hps.data.training_files, hps.data) 46 | train_sampler = torch.utils.data.distributed.DistributedSampler( 47 | train_dataset, 48 | num_replicas=hvd.size(), 49 | rank=hvd.rank(), 50 | shuffle=True) 51 | collate_fn = TextMelCollate(1) 52 | train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, 53 | batch_size=hps.train.batch_size, pin_memory=True, 54 | drop_last=True, collate_fn=collate_fn, sampler=train_sampler) 55 | if hvd.local_rank() == 0: 56 | val_dataset = TextMelLoader(hps.data.validation_files, hps.data) 57 | val_loader = DataLoader(val_dataset, num_workers=8, shuffle=False, 58 | batch_size=hps.train.batch_size, pin_memory=True, 59 | drop_last=True, collate_fn=collate_fn) 60 | 61 | generator = models.DiffusionGenerator( 62 | n_vocab=len(symbols) + getattr(hps.data, "add_blank", False), 63 | enc_out_channels=hps.data.n_mel_channels, 64 | **hps.model).cuda(hvd.local_rank()) 65 | 66 | optimizer_g = commons.Adam(scheduler=hps.train.scheduler, dim_model=hps.model.hidden_channels, lr=hps.train.learning_rate) 67 | t_optimizer = torch.optim.Adam(generator.parameters(), lr=optimizer_g.get_lr(), betas=hps.train.betas, eps=hps.train.eps) 68 | t_optimizer = hvd.DistributedOptimizer(t_optimizer, named_parameters=generator.named_parameters()) 69 | hvd.broadcast_parameters(generator.state_dict(), root_rank=0) 70 | optimizer_g.set_optimizer(t_optimizer) 71 | 72 | if hps.train.fp16_run: 73 | generator, optimizer_g._optim = amp.initialize(generator, optimizer_g._optim, opt_level="O1") 74 | epoch_str = 1 75 | global_step = 0 76 | try: 77 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), generator, optimizer_g) 78 | epoch_str += 1 79 | optimizer_g.step_num = (epoch_str - 1) * len(train_loader) 80 | optimizer_g._update_learning_rate() 81 | global_step = (epoch_str - 1) * len(train_loader) 82 | except: 83 | if hps.train.ddi and os.path.isfile(os.path.join(hps.model_dir, "ddi_G.pth")): 84 | _ = utils.load_checkpoint(os.path.join(hps.model_dir, "ddi_G.pth"), generator, optimizer_g) 85 | 86 | for epoch in range(epoch_str, hps.train.epochs + 1): 87 | if hvd.local_rank()==0: 88 | train(hvd.local_rank(), epoch, hps, generator, optimizer_g, train_loader, logger, writer) 89 | evaluate(hvd.local_rank(), epoch, hps, generator, optimizer_g, val_loader, logger, writer_eval) 90 | utils.save_checkpoint(generator, optimizer_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(epoch))) 91 | else: 92 | train(hvd.local_rank(), epoch, hps, generator, optimizer_g, train_loader, None, None) 93 | 94 | 95 | def train(rank, epoch, hps, generator, optimizer_g, train_loader, logger, writer): 96 | train_loader.sampler.set_epoch(epoch) 97 | global global_step 98 | 99 | generator.train() 100 | for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(train_loader): 101 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True) 102 | y, y_lengths = y.cuda(rank, non_blocking=True), 
y_lengths.cuda(rank, non_blocking=True) 103 | 104 | # Train Generator 105 | optimizer_g.zero_grad() 106 | 107 | grad_loss, (z_m, z_logs, z_mask), (attn, logw, logw_) = generator(x, x_lengths, y, y_lengths, gen=False) 108 | l_mle = commons.mle_loss(y, z_m, z_logs, z_mask) # z_logs is not used because we use N(mu, I) as the X_t 109 | l_length = commons.duration_loss(logw, logw_, x_lengths) 110 | 111 | loss_gs = [grad_loss, l_mle, l_length] 112 | loss_g = sum(loss_gs) 113 | 114 | if hps.train.fp16_run: 115 | with amp.scale_loss(loss_g, optimizer_g._optim) as scaled_loss: 116 | scaled_loss.backward() 117 | grad_norm = commons.clip_grad_value_(amp.master_params(optimizer_g._optim), 5) 118 | else: 119 | loss_g.backward() 120 | grad_norm = commons.clip_grad_value_(generator.parameters(), 5) 121 | optimizer_g.step() 122 | 123 | if rank==0: 124 | if batch_idx % hps.train.log_interval == 0: 125 | logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 126 | epoch, batch_idx * len(x), len(train_loader.dataset), 127 | 100. * batch_idx / len(train_loader), 128 | loss_g.item())) 129 | logger.info([x.item() for x in loss_gs] + [global_step, optimizer_g.get_lr()]) 130 | 131 | if batch_idx % (hps.train.log_interval * 1000) == 0: 132 | (y_gen, *_), *_ = generator(x[:1], x_lengths[:1], gen=True) 133 | scalar_dict = {"loss/g/total": loss_g, "learning_rate": optimizer_g.get_lr(), "grad_norm": grad_norm} 134 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(loss_gs)}) 135 | utils.summarize( 136 | writer=writer, 137 | global_step=global_step, 138 | images={"y_org": utils.plot_spectrogram_to_numpy(y[0].data.cpu().numpy()), 139 | "y_gen": utils.plot_spectrogram_to_numpy(y_gen[0].data.cpu().numpy()), 140 | "attn": utils.plot_alignment_to_numpy(attn[0,0].data.cpu().numpy()), 141 | }, 142 | scalars=scalar_dict) 143 | 144 | global_step += 1 145 | 146 | if rank == 0: 147 | logger.info('====> Epoch: {}'.format(epoch)) 148 | 149 | 150 | def evaluate(rank, epoch, hps, generator, optimizer_g, val_loader, logger, writer_eval): 151 | if rank == 0: 152 | global global_step 153 | generator.eval() 154 | losses_tot = [] 155 | with torch.no_grad(): 156 | for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(val_loader): 157 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True) 158 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) 159 | 160 | 161 | grad_loss, (z_m, z_logs, z_mask), (attn, logw, logw_) = generator(x, x_lengths, y, y_lengths, gen=False) 162 | l_mle = commons.mle_loss(y, z_m, torch.ones_like(z_m), z_mask) # z_logs is not used because we use N(mu, I) as the X_t 163 | l_length = commons.duration_loss(logw, logw_, x_lengths) 164 | 165 | loss_gs = [grad_loss, l_mle, l_length] 166 | loss_g = sum(loss_gs) 167 | 168 | if batch_idx == 0: 169 | losses_tot = loss_gs 170 | else: 171 | losses_tot = [x + y for (x, y) in zip(losses_tot, loss_gs)] 172 | 173 | if batch_idx % hps.train.log_interval == 0: 174 | logger.info('Eval Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 175 | epoch, batch_idx * len(x), len(val_loader.dataset), 176 | 100. 
* batch_idx / len(val_loader), 177 | loss_g.item())) 178 | logger.info([x.item() for x in loss_gs]) 179 | 180 | 181 | losses_tot = [x/len(val_loader) for x in losses_tot] 182 | loss_tot = sum(losses_tot) 183 | scalar_dict = {"loss/g/total": loss_tot} 184 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_tot)}) 185 | utils.summarize( 186 | writer=writer_eval, 187 | global_step=global_step, 188 | scalars=scalar_dict) 189 | logger.info('====> Epoch: {}'.format(epoch)) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | def load_checkpoint(checkpoint_path, model, optimizer=None): 18 | assert os.path.isfile(checkpoint_path) 19 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 20 | iteration = 1 21 | if 'iteration' in checkpoint_dict.keys(): 22 | iteration = checkpoint_dict['iteration'] 23 | if 'learning_rate' in checkpoint_dict.keys(): 24 | learning_rate = checkpoint_dict['learning_rate'] 25 | if optimizer is not None and 'optimizer' in checkpoint_dict.keys(): 26 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 27 | saved_state_dict = checkpoint_dict['model'] 28 | if hasattr(model, 'module'): 29 | state_dict = model.module.state_dict() 30 | else: 31 | state_dict = model.state_dict() 32 | new_state_dict= {} 33 | for k, v in state_dict.items(): 34 | try: 35 | new_state_dict[k] = saved_state_dict[k] 36 | except: 37 | logger.info("%s is not in the checkpoint" % k) 38 | new_state_dict[k] = v 39 | if hasattr(model, 'module'): 40 | model.module.load_state_dict(new_state_dict) 41 | else: 42 | model.load_state_dict(new_state_dict) 43 | logger.info("Loaded checkpoint '{}' (iteration {})" .format( 44 | checkpoint_path, iteration)) 45 | return model, optimizer, learning_rate, iteration 46 | 47 | 48 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 49 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 50 | iteration, checkpoint_path)) 51 | if hasattr(model, 'module'): 52 | state_dict = model.module.state_dict() 53 | else: 54 | state_dict = model.state_dict() 55 | torch.save({'model': state_dict, 56 | 'iteration': iteration, 57 | 'optimizer': optimizer.state_dict(), 58 | 'learning_rate': learning_rate}, checkpoint_path) 59 | 60 | 61 | def summarize(writer, global_step, scalars={}, histograms={}, images={}): 62 | for k, v in scalars.items(): 63 | writer.add_scalar(k, v, global_step) 64 | for k, v in histograms.items(): 65 | writer.add_histogram(k, v, global_step) 66 | for k, v in images.items(): 67 | writer.add_image(k, v, global_step, dataformats='HWC') 68 | 69 | 70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 71 | f_list = glob.glob(os.path.join(dir_path, regex)) 72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 73 | x = f_list[-1] 74 | print(x) 75 | return x 76 | 77 | 78 | def plot_spectrogram_to_numpy(spectrogram): 79 | global MATPLOTLIB_FLAG 80 | if not MATPLOTLIB_FLAG: 81 | import matplotlib 82 | matplotlib.use("Agg") 83 | MATPLOTLIB_FLAG = True 84 
| mpl_logger = logging.getLogger('matplotlib') 85 | mpl_logger.setLevel(logging.WARNING) 86 | import matplotlib.pylab as plt 87 | import numpy as np 88 | 89 | fig, ax = plt.subplots() 90 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 91 | interpolation='none') 92 | plt.colorbar(im, ax=ax) 93 | plt.xlabel("Frames") 94 | plt.ylabel("Channels") 95 | plt.tight_layout() 96 | 97 | fig.canvas.draw() 98 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 99 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 100 | plt.close() 101 | return data 102 | 103 | 104 | def plot_alignment_to_numpy(alignment, info=None): 105 | global MATPLOTLIB_FLAG 106 | if not MATPLOTLIB_FLAG: 107 | import matplotlib 108 | matplotlib.use("Agg") 109 | MATPLOTLIB_FLAG = True 110 | mpl_logger = logging.getLogger('matplotlib') 111 | mpl_logger.setLevel(logging.WARNING) 112 | import matplotlib.pylab as plt 113 | import numpy as np 114 | 115 | fig, ax = plt.subplots(figsize=(6, 4)) 116 | im = ax.imshow(alignment, aspect='auto', origin='lower', 117 | interpolation='none') 118 | fig.colorbar(im, ax=ax) 119 | xlabel = 'Decoder timestep' 120 | if info is not None: 121 | xlabel += '\n\n' + info 122 | plt.xlabel(xlabel) 123 | plt.ylabel('Encoder timestep') 124 | plt.tight_layout() 125 | 126 | fig.canvas.draw() 127 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 128 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 129 | plt.close() 130 | return data 131 | 132 | 133 | def load_wav_to_torch(full_path): 134 | sampling_rate, data = read(full_path) 135 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 136 | 137 | 138 | def load_filepaths_and_text(filename, split="|"): 139 | with open(filename, encoding='utf-8') as f: 140 | filepaths_and_text = [line.strip().split(split) for line in f] 141 | return filepaths_and_text 142 | 143 | 144 | def get_hparams(init=True): 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 147 | help='JSON file for configuration') 148 | parser.add_argument('-l', '--logdir', type=str, required=True) 149 | parser.add_argument('-m', '--model', type=str, required=True, 150 | help='Model name') 151 | 152 | args = parser.parse_args() 153 | model_dir = os.path.join(args.logdir, args.model) 154 | 155 | if not os.path.exists(model_dir): 156 | os.makedirs(model_dir, exist_ok=True) 157 | 158 | config_path = args.config 159 | config_save_path = os.path.join(model_dir, "config.json") 160 | if init: 161 | with open(config_path, "r") as f: 162 | data = f.read() 163 | with open(config_save_path, "w") as f: 164 | f.write(data) 165 | else: 166 | with open(config_save_path, "r") as f: 167 | data = f.read() 168 | config = json.loads(data) 169 | 170 | hparams = HParams(**config) 171 | hparams.model_dir = model_dir 172 | return hparams 173 | 174 | 175 | def get_hparams_from_dir(model_dir): 176 | config_save_path = os.path.join(model_dir, "config.json") 177 | with open(config_save_path, "r") as f: 178 | data = f.read() 179 | config = json.loads(data) 180 | 181 | hparams =HParams(**config) 182 | hparams.model_dir = model_dir 183 | return hparams 184 | 185 | 186 | def get_hparams_from_file(config_path): 187 | with open(config_path, "r") as f: 188 | data = f.read() 189 | config = json.loads(data) 190 | 191 | hparams =HParams(**config) 192 | return hparams 193 | 194 | 195 | def check_git_hash(model_dir): 196 | source_dir = 
os.path.dirname(os.path.realpath(__file__)) 197 | if not os.path.exists(os.path.join(source_dir, ".git")): 198 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 199 | source_dir 200 | )) 201 | return 202 | 203 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 204 | 205 | path = os.path.join(model_dir, "githash") 206 | if os.path.exists(path): 207 | saved_hash = open(path).read() 208 | if saved_hash != cur_hash: 209 | logger.warn("git hash values are different. {}(saved) != {}(current)".format( 210 | saved_hash[:8], cur_hash[:8])) 211 | else: 212 | open(path, "w").write(cur_hash) 213 | 214 | 215 | def get_logger(model_dir, filename="train.log"): 216 | global logger 217 | logger = logging.getLogger(os.path.basename(model_dir)) 218 | logger.setLevel(logging.DEBUG) 219 | 220 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 221 | if not os.path.exists(model_dir): 222 | os.makedirs(model_dir) 223 | h = logging.FileHandler(os.path.join(model_dir, filename)) 224 | h.setLevel(logging.DEBUG) 225 | h.setFormatter(formatter) 226 | logger.addHandler(h) 227 | return logger 228 | 229 | 230 | class HParams(): 231 | def __init__(self, **kwargs): 232 | for k, v in kwargs.items(): 233 | if type(v) == dict: 234 | v = HParams(**v) 235 | self[k] = v 236 | 237 | def keys(self): 238 | return self.__dict__.keys() 239 | 240 | def items(self): 241 | return self.__dict__.items() 242 | 243 | def values(self): 244 | return self.__dict__.values() 245 | 246 | def __len__(self): 247 | return len(self.__dict__) 248 | 249 | def __getitem__(self, key): 250 | return getattr(self, key) 251 | 252 | def __setitem__(self, key, value): 253 | return setattr(self, key, value) 254 | 255 | def __contains__(self, key): 256 | return key in self.__dict__ 257 | 258 | def __repr__(self): 259 | return self.__dict__.__repr__() 260 | -------------------------------------------------------------------------------- /waveglow/train.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import argparse 28 | import json 29 | import os 30 | import torch 31 | 32 | #=====START: ADDED FOR DISTRIBUTED====== 33 | from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor 34 | from torch.utils.data.distributed import DistributedSampler 35 | #=====END: ADDED FOR DISTRIBUTED====== 36 | 37 | from torch.utils.data import DataLoader 38 | from glow import WaveGlow, WaveGlowLoss 39 | from mel2samp import Mel2Samp 40 | 41 | def load_checkpoint(checkpoint_path, model, optimizer): 42 | assert os.path.isfile(checkpoint_path) 43 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 44 | iteration = checkpoint_dict['iteration'] 45 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 46 | model_for_loading = checkpoint_dict['model'] 47 | model.load_state_dict(model_for_loading.state_dict()) 48 | print("Loaded checkpoint '{}' (iteration {})" .format( 49 | checkpoint_path, iteration)) 50 | return model, optimizer, iteration 51 | 52 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): 53 | print("Saving model and optimizer state at iteration {} to {}".format( 54 | iteration, filepath)) 55 | model_for_saving = WaveGlow(**waveglow_config).cuda() 56 | model_for_saving.load_state_dict(model.state_dict()) 57 | torch.save({'model': model_for_saving, 58 | 'iteration': iteration, 59 | 'optimizer': optimizer.state_dict(), 60 | 'learning_rate': learning_rate}, filepath) 61 | 62 | def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate, 63 | sigma, iters_per_checkpoint, batch_size, seed, fp16_run, 64 | checkpoint_path, with_tensorboard): 65 | torch.manual_seed(seed) 66 | torch.cuda.manual_seed(seed) 67 | #=====START: ADDED FOR DISTRIBUTED====== 68 | if num_gpus > 1: 69 | init_distributed(rank, num_gpus, group_name, **dist_config) 70 | #=====END: ADDED FOR DISTRIBUTED====== 71 | 72 | criterion = WaveGlowLoss(sigma) 73 | model = WaveGlow(**waveglow_config).cuda() 74 | 75 | #=====START: ADDED FOR DISTRIBUTED====== 76 | if num_gpus > 1: 77 | model = apply_gradient_allreduce(model) 78 | #=====END: ADDED FOR DISTRIBUTED====== 79 | 80 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 81 | 82 | if fp16_run: 83 | from apex import amp 84 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 85 | 86 | # Load checkpoint if one exists 87 | iteration = 0 88 | if checkpoint_path != "": 89 | model, optimizer, iteration = load_checkpoint(checkpoint_path, model, 90 | optimizer) 91 | iteration += 1 # next iteration is iteration + 1 92 | 93 | trainset = Mel2Samp(**data_config) 94 | # =====START: ADDED FOR DISTRIBUTED====== 95 | train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None 96 | # =====END: ADDED FOR DISTRIBUTED====== 97 | train_loader = DataLoader(trainset, num_workers=1, shuffle=False, 98 | sampler=train_sampler, 99 | batch_size=batch_size, 100 | 
pin_memory=False, 101 | drop_last=True) 102 | 103 | # Get shared output_directory ready 104 | if rank == 0: 105 | if not os.path.isdir(output_directory): 106 | os.makedirs(output_directory) 107 | os.chmod(output_directory, 0o775) 108 | print("output directory", output_directory) 109 | 110 | if with_tensorboard and rank == 0: 111 | from tensorboardX import SummaryWriter 112 | logger = SummaryWriter(os.path.join(output_directory, 'logs')) 113 | 114 | model.train() 115 | epoch_offset = max(0, int(iteration / len(train_loader))) 116 | # ================ MAIN TRAINNIG LOOP! =================== 117 | for epoch in range(epoch_offset, epochs): 118 | print("Epoch: {}".format(epoch)) 119 | for i, batch in enumerate(train_loader): 120 | model.zero_grad() 121 | 122 | mel, audio = batch 123 | mel = torch.autograd.Variable(mel.cuda()) 124 | audio = torch.autograd.Variable(audio.cuda()) 125 | outputs = model((mel, audio)) 126 | 127 | loss = criterion(outputs) 128 | if num_gpus > 1: 129 | reduced_loss = reduce_tensor(loss.data, num_gpus).item() 130 | else: 131 | reduced_loss = loss.item() 132 | 133 | if fp16_run: 134 | with amp.scale_loss(loss, optimizer) as scaled_loss: 135 | scaled_loss.backward() 136 | else: 137 | loss.backward() 138 | 139 | optimizer.step() 140 | 141 | print("{}:\t{:.9f}".format(iteration, reduced_loss)) 142 | if with_tensorboard and rank == 0: 143 | logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch) 144 | 145 | if (iteration % iters_per_checkpoint == 0): 146 | if rank == 0: 147 | checkpoint_path = "{}/waveglow_{}".format( 148 | output_directory, iteration) 149 | save_checkpoint(model, optimizer, learning_rate, iteration, 150 | checkpoint_path) 151 | 152 | iteration += 1 153 | 154 | if __name__ == "__main__": 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('-c', '--config', type=str, 157 | help='JSON file for configuration') 158 | parser.add_argument('-r', '--rank', type=int, default=0, 159 | help='rank of process for distributed') 160 | parser.add_argument('-g', '--group_name', type=str, default='', 161 | help='name of group for distributed') 162 | args = parser.parse_args() 163 | 164 | # Parse configs. Globals nicer in this case 165 | with open(args.config) as f: 166 | data = f.read() 167 | config = json.loads(data) 168 | train_config = config["train_config"] 169 | global data_config 170 | data_config = config["data_config"] 171 | global dist_config 172 | dist_config = config["dist_config"] 173 | global waveglow_config 174 | waveglow_config = config["waveglow_config"] 175 | 176 | num_gpus = torch.cuda.device_count() 177 | if num_gpus > 1: 178 | if args.group_name == '': 179 | print("WARNING: Multiple GPUs detected but no distributed group set") 180 | print("Only running 1 GPU. 
Use distributed.py for multiple GPUs") 181 | num_gpus = 1 182 | 183 | if num_gpus == 1 and args.rank != 0: 184 | raise Exception("Doing single GPU training on rank > 0") 185 | 186 | torch.backends.cudnn.enabled = True 187 | torch.backends.cudnn.benchmark = False 188 | train(num_gpus, args.rank, args.group_name, **train_config) 189 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | import commons 10 | 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, channels, eps=1e-4): 14 | super().__init__() 15 | self.channels = channels 16 | self.eps = eps 17 | 18 | self.gamma = nn.Parameter(torch.ones(channels)) 19 | self.beta = nn.Parameter(torch.zeros(channels)) 20 | 21 | def forward(self, x): 22 | n_dims = len(x.shape) 23 | mean = torch.mean(x, 1, keepdim=True) 24 | variance = torch.mean((x -mean)**2, 1, keepdim=True) 25 | 26 | x = (x - mean) * torch.rsqrt(variance + self.eps) 27 | 28 | shape = [1, -1] + [1] * (n_dims - 2) 29 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 30 | return x 31 | 32 | 33 | class ConvReluNorm(nn.Module): 34 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 35 | super().__init__() 36 | self.in_channels = in_channels 37 | self.hidden_channels = hidden_channels 38 | self.out_channels = out_channels 39 | self.kernel_size = kernel_size 40 | self.n_layers = n_layers 41 | self.p_dropout = p_dropout 42 | assert n_layers > 1, "Number of layers should be larger than 0." 43 | 44 | self.conv_layers = nn.ModuleList() 45 | self.norm_layers = nn.ModuleList() 46 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 47 | self.norm_layers.append(LayerNorm(hidden_channels)) 48 | self.relu_drop = nn.Sequential( 49 | nn.ReLU(), 50 | nn.Dropout(p_dropout)) 51 | for _ in range(n_layers-1): 52 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 53 | self.norm_layers.append(LayerNorm(hidden_channels)) 54 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 55 | self.proj.weight.data.zero_() 56 | self.proj.bias.data.zero_() 57 | 58 | def forward(self, x, x_mask): 59 | x_org = x 60 | for i in range(self.n_layers): 61 | x = self.conv_layers[i](x * x_mask) 62 | x = self.norm_layers[i](x) 63 | x = self.relu_drop(x) 64 | x = x_org + self.proj(x) 65 | return x * x_mask 66 | 67 | 68 | class WN(torch.nn.Module): 69 | def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 70 | super(WN, self).__init__() 71 | assert(kernel_size % 2 == 1) 72 | assert(hidden_channels % 2 == 0) 73 | self.in_channels = in_channels 74 | self.hidden_channels =hidden_channels 75 | self.kernel_size = kernel_size, 76 | self.dilation_rate = dilation_rate 77 | self.n_layers = n_layers 78 | self.gin_channels = gin_channels 79 | self.p_dropout = p_dropout 80 | 81 | self.in_layers = torch.nn.ModuleList() 82 | self.res_skip_layers = torch.nn.ModuleList() 83 | self.drop = nn.Dropout(p_dropout) 84 | 85 | if gin_channels != 0: 86 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 87 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 88 | 89 | for i in range(n_layers): 
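            # Each iteration adds one WaveNet-style layer: a weight-normalized dilated
            # Conv1d (dilation = dilation_rate ** i, "same" padding) whose output of
            # 2*hidden_channels feeds the gated tanh/sigmoid unit, followed by a 1x1
            # res/skip projection. The final layer only produces skip channels, since
            # its residual half would have no following layer to consume it.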
90 | dilation = dilation_rate ** i 91 | padding = int((kernel_size * dilation - dilation) / 2) 92 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 93 | dilation=dilation, padding=padding) 94 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 95 | self.in_layers.append(in_layer) 96 | 97 | # last one is not necessary 98 | if i < n_layers - 1: 99 | res_skip_channels = 2 * hidden_channels 100 | else: 101 | res_skip_channels = hidden_channels 102 | 103 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 104 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 105 | self.res_skip_layers.append(res_skip_layer) 106 | 107 | def forward(self, x, x_mask=None, g=None, **kwargs): 108 | output = torch.zeros_like(x) 109 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 110 | 111 | if g is not None: 112 | g = self.cond_layer(g) 113 | 114 | for i in range(self.n_layers): 115 | x_in = self.in_layers[i](x) 116 | x_in = self.drop(x_in) 117 | if g is not None: 118 | cond_offset = i * 2 * self.hidden_channels 119 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 120 | else: 121 | g_l = torch.zeros_like(x_in) 122 | 123 | acts = commons.fused_add_tanh_sigmoid_multiply( 124 | x_in, 125 | g_l, 126 | n_channels_tensor) 127 | 128 | res_skip_acts = self.res_skip_layers[i](acts) 129 | if i < self.n_layers - 1: 130 | x = (x + res_skip_acts[:,:self.hidden_channels,:]) * x_mask 131 | output = output + res_skip_acts[:,self.hidden_channels:,:] 132 | else: 133 | output = output + res_skip_acts 134 | return output * x_mask 135 | 136 | def remove_weight_norm(self): 137 | if self.gin_channels != 0: 138 | torch.nn.utils.remove_weight_norm(self.cond_layer) 139 | for l in self.in_layers: 140 | torch.nn.utils.remove_weight_norm(l) 141 | for l in self.res_skip_layers: 142 | torch.nn.utils.remove_weight_norm(l) 143 | 144 | 145 | class ActNorm(nn.Module): 146 | def __init__(self, channels, ddi=False, **kwargs): 147 | super().__init__() 148 | self.channels = channels 149 | self.initialized = not ddi 150 | 151 | self.logs = nn.Parameter(torch.zeros(1, channels, 1)) 152 | self.bias = nn.Parameter(torch.zeros(1, channels, 1)) 153 | 154 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 155 | if x_mask is None: 156 | x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) 157 | x_len = torch.sum(x_mask, [1, 2]) 158 | if not self.initialized: 159 | self.initialize(x, x_mask) 160 | self.initialized = True 161 | 162 | if reverse: 163 | z = (x - self.bias) * torch.exp(-self.logs) * x_mask 164 | logdet = None 165 | else: 166 | z = (self.bias + torch.exp(self.logs) * x) * x_mask 167 | logdet = torch.sum(self.logs) * x_len # [b] 168 | 169 | return z, logdet 170 | 171 | def store_inverse(self): 172 | pass 173 | 174 | def set_ddi(self, ddi): 175 | self.initialized = not ddi 176 | 177 | def initialize(self, x, x_mask): 178 | with torch.no_grad(): 179 | denom = torch.sum(x_mask, [0, 2]) 180 | m = torch.sum(x * x_mask, [0, 2]) / denom 181 | m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom 182 | v = m_sq - (m ** 2) 183 | logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) 184 | 185 | bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) 186 | logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) 187 | 188 | self.bias.data.copy_(bias_init) 189 | self.logs.data.copy_(logs_init) 190 | 191 | 192 | class InvConvNear(nn.Module): 193 | def __init__(self, channels, 
n_split=4, no_jacobian=False, **kwargs): 194 | super().__init__() 195 | assert(n_split % 2 == 0) 196 | self.channels = channels 197 | self.n_split = n_split 198 | self.no_jacobian = no_jacobian 199 | 200 | w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] 201 | if torch.det(w_init) < 0: 202 | w_init[:,0] = -1 * w_init[:,0] 203 | self.weight = nn.Parameter(w_init) 204 | 205 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 206 | b, c, t = x.size() 207 | assert(c % self.n_split == 0) 208 | if x_mask is None: 209 | x_mask = 1 210 | x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t 211 | else: 212 | x_len = torch.sum(x_mask, [1, 2]) 213 | 214 | x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) 215 | x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t) 216 | 217 | if reverse: 218 | if hasattr(self, "weight_inv"): 219 | weight = self.weight_inv 220 | else: 221 | weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 222 | logdet = None 223 | else: 224 | weight = self.weight 225 | if self.no_jacobian: 226 | logdet = 0 227 | else: 228 | logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b] 229 | 230 | weight = weight.view(self.n_split, self.n_split, 1, 1) 231 | z = F.conv2d(x, weight) 232 | 233 | z = z.view(b, 2, self.n_split // 2, c // self.n_split, t) 234 | z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask 235 | return z, logdet 236 | 237 | def store_inverse(self): 238 | self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 239 | 240 | -------------------------------------------------------------------------------- /waveglow/glow_old.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from glow import Invertible1x1Conv, remove 4 | 5 | 6 | @torch.jit.script 7 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 8 | n_channels_int = n_channels[0] 9 | in_act = input_a+input_b 10 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 11 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 12 | acts = t_act * s_act 13 | return acts 14 | 15 | 16 | class WN(torch.nn.Module): 17 | """ 18 | This is the WaveNet like layer for the affine coupling. The primary difference 19 | from WaveNet is the convolutions need not be causal. There is also no dilation 20 | size reset. The dilation only doubles on each layer 21 | """ 22 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, 23 | kernel_size): 24 | super(WN, self).__init__() 25 | assert(kernel_size % 2 == 1) 26 | assert(n_channels % 2 == 0) 27 | self.n_layers = n_layers 28 | self.n_channels = n_channels 29 | self.in_layers = torch.nn.ModuleList() 30 | self.res_skip_layers = torch.nn.ModuleList() 31 | self.cond_layers = torch.nn.ModuleList() 32 | 33 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1) 34 | start = torch.nn.utils.weight_norm(start, name='weight') 35 | self.start = start 36 | 37 | # Initializing last layer to 0 makes the affine coupling layers 38 | # do nothing at first. 
This helps with training stability 39 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) 40 | end.weight.data.zero_() 41 | end.bias.data.zero_() 42 | self.end = end 43 | 44 | for i in range(n_layers): 45 | dilation = 2 ** i 46 | padding = int((kernel_size*dilation - dilation)/2) 47 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, 48 | dilation=dilation, padding=padding) 49 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 50 | self.in_layers.append(in_layer) 51 | 52 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1) 53 | cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 54 | self.cond_layers.append(cond_layer) 55 | 56 | # last one is not necessary 57 | if i < n_layers - 1: 58 | res_skip_channels = 2*n_channels 59 | else: 60 | res_skip_channels = n_channels 61 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 62 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 63 | self.res_skip_layers.append(res_skip_layer) 64 | 65 | def forward(self, forward_input): 66 | audio, spect = forward_input 67 | audio = self.start(audio) 68 | 69 | for i in range(self.n_layers): 70 | acts = fused_add_tanh_sigmoid_multiply( 71 | self.in_layers[i](audio), 72 | self.cond_layers[i](spect), 73 | torch.IntTensor([self.n_channels])) 74 | 75 | res_skip_acts = self.res_skip_layers[i](acts) 76 | if i < self.n_layers - 1: 77 | audio = res_skip_acts[:,:self.n_channels,:] + audio 78 | skip_acts = res_skip_acts[:,self.n_channels:,:] 79 | else: 80 | skip_acts = res_skip_acts 81 | 82 | if i == 0: 83 | output = skip_acts 84 | else: 85 | output = skip_acts + output 86 | return self.end(output) 87 | 88 | 89 | class WaveGlow(torch.nn.Module): 90 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, 91 | n_early_size, WN_config): 92 | super(WaveGlow, self).__init__() 93 | 94 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, 95 | n_mel_channels, 96 | 1024, stride=256) 97 | assert(n_group % 2 == 0) 98 | self.n_flows = n_flows 99 | self.n_group = n_group 100 | self.n_early_every = n_early_every 101 | self.n_early_size = n_early_size 102 | self.WN = torch.nn.ModuleList() 103 | self.convinv = torch.nn.ModuleList() 104 | 105 | n_half = int(n_group/2) 106 | 107 | # Set up layers with the right sizes based on how many dimensions 108 | # have been output already 109 | n_remaining_channels = n_group 110 | for k in range(n_flows): 111 | if k % self.n_early_every == 0 and k > 0: 112 | n_half = n_half - int(self.n_early_size/2) 113 | n_remaining_channels = n_remaining_channels - self.n_early_size 114 | self.convinv.append(Invertible1x1Conv(n_remaining_channels)) 115 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) 116 | self.n_remaining_channels = n_remaining_channels # Useful during inference 117 | 118 | def forward(self, forward_input): 119 | return None 120 | """ 121 | forward_input[0] = audio: batch x time 122 | forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time 123 | """ 124 | """ 125 | spect, audio = forward_input 126 | 127 | # Upsample spectrogram to size of audio 128 | spect = self.upsample(spect) 129 | assert(spect.size(2) >= audio.size(1)) 130 | if spect.size(2) > audio.size(1): 131 | spect = spect[:, :, :audio.size(1)] 132 | 133 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 134 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 135 | 136 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 
2, 1) 137 | output_audio = [] 138 | s_list = [] 139 | s_conv_list = [] 140 | 141 | for k in range(self.n_flows): 142 | if k%4 == 0 and k > 0: 143 | output_audio.append(audio[:,:self.n_multi,:]) 144 | audio = audio[:,self.n_multi:,:] 145 | 146 | # project to new basis 147 | audio, s = self.convinv[k](audio) 148 | s_conv_list.append(s) 149 | 150 | n_half = int(audio.size(1)/2) 151 | if k%2 == 0: 152 | audio_0 = audio[:,:n_half,:] 153 | audio_1 = audio[:,n_half:,:] 154 | else: 155 | audio_1 = audio[:,:n_half,:] 156 | audio_0 = audio[:,n_half:,:] 157 | 158 | output = self.nn[k]((audio_0, spect)) 159 | s = output[:, n_half:, :] 160 | b = output[:, :n_half, :] 161 | audio_1 = torch.exp(s)*audio_1 + b 162 | s_list.append(s) 163 | 164 | if k%2 == 0: 165 | audio = torch.cat([audio[:,:n_half,:], audio_1],1) 166 | else: 167 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1) 168 | output_audio.append(audio) 169 | return torch.cat(output_audio,1), s_list, s_conv_list 170 | """ 171 | 172 | def infer(self, spect, sigma=1.0): 173 | spect = self.upsample(spect) 174 | # trim conv artifacts. maybe pad spec to kernel multiple 175 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] 176 | spect = spect[:, :, :-time_cutoff] 177 | 178 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 179 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 180 | 181 | if spect.type() == 'torch.cuda.HalfTensor': 182 | audio = torch.cuda.HalfTensor(spect.size(0), 183 | self.n_remaining_channels, 184 | spect.size(2)).normal_() 185 | else: 186 | audio = torch.cuda.FloatTensor(spect.size(0), 187 | self.n_remaining_channels, 188 | spect.size(2)).normal_() 189 | 190 | audio = torch.autograd.Variable(sigma*audio) 191 | 192 | for k in reversed(range(self.n_flows)): 193 | n_half = int(audio.size(1)/2) 194 | if k%2 == 0: 195 | audio_0 = audio[:,:n_half,:] 196 | audio_1 = audio[:,n_half:,:] 197 | else: 198 | audio_1 = audio[:,:n_half,:] 199 | audio_0 = audio[:,n_half:,:] 200 | 201 | output = self.WN[k]((audio_0, spect)) 202 | s = output[:, n_half:, :] 203 | b = output[:, :n_half, :] 204 | audio_1 = (audio_1 - b)/torch.exp(s) 205 | if k%2 == 0: 206 | audio = torch.cat([audio[:,:n_half,:], audio_1],1) 207 | else: 208 | audio = torch.cat([audio_1, audio[:,n_half:,:]], 1) 209 | 210 | audio = self.convinv[k](audio, reverse=True) 211 | 212 | if k%4 == 0 and k > 0: 213 | if spect.type() == 'torch.cuda.HalfTensor': 214 | z = torch.cuda.HalfTensor(spect.size(0), 215 | self.n_early_size, 216 | spect.size(2)).normal_() 217 | else: 218 | z = torch.cuda.FloatTensor(spect.size(0), 219 | self.n_early_size, 220 | spect.size(2)).normal_() 221 | audio = torch.cat((sigma*z, audio),1) 222 | 223 | return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data 224 | 225 | @staticmethod 226 | def remove_weightnorm(model): 227 | waveglow = model 228 | for WN in waveglow.WN: 229 | WN.start = torch.nn.utils.remove_weight_norm(WN.start) 230 | WN.in_layers = remove(WN.in_layers) 231 | WN.cond_layers = remove(WN.cond_layers) 232 | WN.res_skip_layers = remove(WN.res_skip_layers) 233 | return waveglow 234 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | import os 6 | 7 | import commons 8 | from utils import load_wav_to_torch, load_filepaths_and_text 9 | from text import 
text_to_sequence, cmudict 10 | from text.symbols import symbols 11 | 12 | 13 | class TextMelLoader(torch.utils.data.Dataset): 14 | """ 15 | 1) loads audio,text pairs 16 | 2) normalizes text and converts them to sequences of one-hot vectors 17 | 3) computes mel-spectrograms from audio files. 18 | """ 19 | def __init__(self, audiopaths_and_text, hparams): 20 | self.hparams = hparams 21 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 22 | self.text_cleaners = hparams.text_cleaners 23 | self.max_wav_value = hparams.max_wav_value 24 | self.sampling_rate = hparams.sampling_rate 25 | self.load_mel_from_disk = hparams.load_mel_from_disk 26 | self.add_noise = hparams.add_noise 27 | self.add_blank = getattr(hparams, "add_blank", False) # improved version 28 | if getattr(hparams, "cmudict_path", None) is not None: 29 | self.cmudict = cmudict.CMUDict(hparams.cmudict_path) 30 | self.stft = commons.TacotronSTFT( 31 | hparams.filter_length, hparams.hop_length, hparams.win_length, 32 | hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, 33 | hparams.mel_fmax) 34 | random.seed(1234) 35 | random.shuffle(self.audiopaths_and_text) 36 | 37 | def get_mel_text_pair(self, audiopath_and_text, prefix=""): 38 | # separate filename and text 39 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1] 40 | if prefix != "": 41 | audiopath = os.path.join(prefix, audiopath) 42 | text = self.get_text(text) 43 | mel = self.get_mel(audiopath) 44 | return (text, mel) 45 | 46 | def get_mel(self, filename): 47 | if not self.load_mel_from_disk: 48 | audio, sampling_rate = load_wav_to_torch(filename) 49 | if sampling_rate != self.stft.sampling_rate: 50 | raise ValueError("{} {} SR doesn't match target {} SR".format( 51 | sampling_rate, self.stft.sampling_rate)) 52 | if self.add_noise: 53 | audio = audio + torch.rand_like(audio) 54 | audio_norm = audio / self.max_wav_value 55 | audio_norm = audio_norm.unsqueeze(0) 56 | melspec = self.stft.mel_spectrogram(audio_norm) 57 | melspec = torch.squeeze(melspec, 0) 58 | else: 59 | melspec = torch.from_numpy(np.load(filename)) 60 | assert melspec.size(0) == self.stft.n_mel_channels, ( 61 | 'Mel dimension mismatch: given {}, expected {}'.format( 62 | melspec.size(0), self.stft.n_mel_channels)) 63 | 64 | return melspec 65 | 66 | def get_text(self, text): 67 | text_norm = text_to_sequence(text, self.text_cleaners, getattr(self, "cmudict", None)) 68 | if self.add_blank: 69 | text_norm = commons.intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols) 70 | text_norm = torch.IntTensor(text_norm) 71 | return text_norm 72 | 73 | def __getitem__(self, index): 74 | return self.get_mel_text_pair(self.audiopaths_and_text[index], prefix=self.hparams.audio_path_prefix) 75 | 76 | def __len__(self): 77 | return len(self.audiopaths_and_text) 78 | 79 | 80 | class TextMelCollate(): 81 | """ Zero-pads model inputs and targets based on number of frames per step 82 | """ 83 | def __init__(self, n_frames_per_step=1): 84 | self.n_frames_per_step = n_frames_per_step 85 | 86 | def __call__(self, batch): 87 | """Collate's training batch from normalized text and mel-spectrogram 88 | PARAMS 89 | ------ 90 | batch: [text_normalized, mel_normalized] 91 | """ 92 | # Right zero-pad all one-hot text sequences to max input length 93 | input_lengths, ids_sorted_decreasing = torch.sort( 94 | torch.LongTensor([len(x[0]) for x in batch]), 95 | dim=0, descending=True) 96 | max_input_len = input_lengths[0] 97 | 98 | text_padded = 
torch.LongTensor(len(batch), max_input_len) 99 | text_padded.zero_() 100 | for i in range(len(ids_sorted_decreasing)): 101 | text = batch[ids_sorted_decreasing[i]][0] 102 | text_padded[i, :text.size(0)] = text 103 | 104 | # Right zero-pad mel-spec 105 | num_mels = batch[0][1].size(0) 106 | max_target_len = max([x[1].size(1) for x in batch]) 107 | if max_target_len % self.n_frames_per_step != 0: 108 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step 109 | assert max_target_len % self.n_frames_per_step == 0 110 | 111 | # include mel padded 112 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 113 | mel_padded.zero_() 114 | output_lengths = torch.LongTensor(len(batch)) 115 | for i in range(len(ids_sorted_decreasing)): 116 | mel = batch[ids_sorted_decreasing[i]][1] 117 | mel_padded[i, :, :mel.size(1)] = mel 118 | output_lengths[i] = mel.size(1) 119 | 120 | return text_padded, input_lengths, mel_padded, output_lengths 121 | 122 | 123 | """Multi speaker version""" 124 | class TextMelSpeakerLoader(torch.utils.data.Dataset): 125 | """ 126 | 1) loads audio, speaker_id, text pairs 127 | 2) normalizes text and converts them to sequences of one-hot vectors 128 | 3) computes mel-spectrograms from audio files. 129 | """ 130 | def __init__(self, audiopaths_sid_text, hparams): 131 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 132 | self.text_cleaners = hparams.text_cleaners 133 | self.max_wav_value = hparams.max_wav_value 134 | self.sampling_rate = hparams.sampling_rate 135 | self.load_mel_from_disk = hparams.load_mel_from_disk 136 | self.add_noise = hparams.add_noise 137 | self.add_blank = getattr(hparams, "add_blank", False) # improved version 138 | self.min_text_len = getattr(hparams, "min_text_len", 1) 139 | self.max_text_len = getattr(hparams, "max_text_len", 190) 140 | if getattr(hparams, "cmudict_path", None) is not None: 141 | self.cmudict = cmudict.CMUDict(hparams.cmudict_path) 142 | self.stft = commons.TacotronSTFT( 143 | hparams.filter_length, hparams.hop_length, hparams.win_length, 144 | hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, 145 | hparams.mel_fmax) 146 | 147 | self._filter_text_len() 148 | random.seed(1234) 149 | random.shuffle(self.audiopaths_sid_text) 150 | 151 | def _filter_text_len(self): 152 | audiopaths_sid_text_new = [] 153 | for audiopath, sid, text in self.audiopaths_sid_text: 154 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 155 | audiopaths_sid_text_new.append([audiopath, sid, text]) 156 | self.audiopaths_sid_text = audiopaths_sid_text_new 157 | 158 | def get_mel_text_speaker_pair(self, audiopath_sid_text): 159 | # separate filename, speaker_id and text 160 | audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2] 161 | text = self.get_text(text) 162 | mel = self.get_mel(audiopath) 163 | sid = self.get_sid(sid) 164 | return (text, mel, sid) 165 | 166 | def get_mel(self, filename): 167 | if not self.load_mel_from_disk: 168 | audio, sampling_rate = load_wav_to_torch(filename) 169 | if sampling_rate != self.stft.sampling_rate: 170 | raise ValueError("{} {} SR doesn't match target {} SR".format( 171 | sampling_rate, self.stft.sampling_rate)) 172 | if self.add_noise: 173 | audio = audio + torch.rand_like(audio) 174 | audio_norm = audio / self.max_wav_value 175 | audio_norm = audio_norm.unsqueeze(0) 176 | melspec = self.stft.mel_spectrogram(audio_norm) 177 | melspec = torch.squeeze(melspec, 0) 178 | else: 179 | melspec = 
torch.from_numpy(np.load(filename)) 180 | assert melspec.size(0) == self.stft.n_mel_channels, ( 181 | 'Mel dimension mismatch: given {}, expected {}'.format( 182 | melspec.size(0), self.stft.n_mel_channels)) 183 | 184 | return melspec 185 | 186 | def get_text(self, text): 187 | text_norm = text_to_sequence(text, self.text_cleaners, getattr(self, "cmudict", None)) 188 | if self.add_blank: 189 | text_norm = commons.intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols) 190 | text_norm = torch.IntTensor(text_norm) 191 | return text_norm 192 | 193 | def get_sid(self, sid): 194 | sid = torch.IntTensor([int(sid)]) 195 | return sid 196 | 197 | def __getitem__(self, index): 198 | return self.get_mel_text_speaker_pair(self.audiopaths_sid_text[index]) 199 | 200 | def __len__(self): 201 | return len(self.audiopaths_sid_text) 202 | 203 | 204 | class TextMelSpeakerCollate(): 205 | """ Zero-pads model inputs and targets based on number of frames per step 206 | """ 207 | def __init__(self, n_frames_per_step=1): 208 | self.n_frames_per_step = n_frames_per_step 209 | 210 | def __call__(self, batch): 211 | """Collate's training batch from normalized text and mel-spectrogram 212 | PARAMS 213 | ------ 214 | batch: [text_normalized, mel_normalized] 215 | """ 216 | # Right zero-pad all one-hot text sequences to max input length 217 | input_lengths, ids_sorted_decreasing = torch.sort( 218 | torch.LongTensor([len(x[0]) for x in batch]), 219 | dim=0, descending=True) 220 | max_input_len = input_lengths[0] 221 | 222 | text_padded = torch.LongTensor(len(batch), max_input_len) 223 | text_padded.zero_() 224 | for i in range(len(ids_sorted_decreasing)): 225 | text = batch[ids_sorted_decreasing[i]][0] 226 | text_padded[i, :text.size(0)] = text 227 | 228 | # Right zero-pad mel-spec 229 | num_mels = batch[0][1].size(0) 230 | max_target_len = max([x[1].size(1) for x in batch]) 231 | if max_target_len % self.n_frames_per_step != 0: 232 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step 233 | assert max_target_len % self.n_frames_per_step == 0 234 | 235 | # include mel padded & sid 236 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 237 | mel_padded.zero_() 238 | output_lengths = torch.LongTensor(len(batch)) 239 | sid = torch.LongTensor(len(batch)) 240 | for i in range(len(ids_sorted_decreasing)): 241 | mel = batch[ids_sorted_decreasing[i]][1] 242 | mel_padded[i, :, :mel.size(1)] = mel 243 | output_lengths[i] = mel.size(1) 244 | sid[i] = batch[ids_sorted_decreasing[i]][2] 245 | 246 | return text_padded, input_lengths, mel_padded, output_lengths, sid 247 | -------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import commons 9 | import modules 10 | from modules import LayerNorm 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=None, block_length=None, **kwargs): 15 | super().__init__() 16 | self.hidden_channels = hidden_channels 17 | self.filter_channels = filter_channels 18 | self.n_heads = n_heads 19 | self.n_layers = n_layers 20 | self.kernel_size = kernel_size 21 | self.p_dropout = p_dropout 22 | self.window_size = window_size 23 | self.block_length = 
block_length 24 | 25 | self.drop = nn.Dropout(p_dropout) 26 | self.attn_layers = nn.ModuleList() 27 | self.norm_layers_1 = nn.ModuleList() 28 | self.ffn_layers = nn.ModuleList() 29 | self.norm_layers_2 = nn.ModuleList() 30 | for i in range(self.n_layers): 31 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout, block_length=block_length)) 32 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 33 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 34 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 35 | 36 | def forward(self, x, x_mask): 37 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 38 | for i in range(self.n_layers): 39 | x = x * x_mask 40 | y = self.attn_layers[i](x, x, attn_mask) 41 | y = self.drop(y) 42 | x = self.norm_layers_1[i](x + y) 43 | 44 | y = self.ffn_layers[i](x, x_mask) 45 | y = self.drop(y) 46 | x = self.norm_layers_2[i](x + y) 47 | x = x * x_mask 48 | return x 49 | 50 | 51 | class CouplingBlock(nn.Module): 52 | def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, sigmoid_scale=False): 53 | super().__init__() 54 | self.in_channels = in_channels 55 | self.hidden_channels = hidden_channels 56 | self.kernel_size = kernel_size 57 | self.dilation_rate = dilation_rate 58 | self.n_layers = n_layers 59 | self.gin_channels = gin_channels 60 | self.p_dropout = p_dropout 61 | self.sigmoid_scale = sigmoid_scale 62 | 63 | start = torch.nn.Conv1d(in_channels//2, hidden_channels, 1) 64 | start = torch.nn.utils.weight_norm(start) 65 | self.start = start 66 | # Initializing last layer to 0 makes the affine coupling layers 67 | # do nothing at first. It helps to stabilze training. 
68 | end = torch.nn.Conv1d(hidden_channels, in_channels, 1) 69 | end.weight.data.zero_() 70 | end.bias.data.zero_() 71 | self.end = end 72 | 73 | self.wn = modules.WN(in_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout) 74 | 75 | 76 | def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): 77 | b, c, t = x.size() 78 | if x_mask is None: 79 | x_mask = 1 80 | x_0, x_1 = x[:,:self.in_channels//2], x[:,self.in_channels//2:] 81 | 82 | x = self.start(x_0) * x_mask 83 | x = self.wn(x, x_mask, g) 84 | out = self.end(x) 85 | 86 | z_0 = x_0 87 | m = out[:, :self.in_channels//2, :] 88 | logs = out[:, self.in_channels//2:, :] 89 | if self.sigmoid_scale: 90 | logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) 91 | 92 | if reverse: 93 | z_1 = (x_1 - m) * torch.exp(-logs) * x_mask 94 | logdet = None 95 | else: 96 | z_1 = (m + torch.exp(logs) * x_1) * x_mask 97 | logdet = torch.sum(logs * x_mask, [1, 2]) 98 | 99 | z = torch.cat([z_0, z_1], 1) 100 | return z, logdet 101 | 102 | def store_inverse(self): 103 | self.wn.remove_weight_norm() 104 | 105 | 106 | class MultiHeadAttention(nn.Module): 107 | def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0., block_length=None, proximal_bias=False, proximal_init=False): 108 | super().__init__() 109 | assert channels % n_heads == 0 110 | 111 | self.channels = channels 112 | self.out_channels = out_channels 113 | self.n_heads = n_heads 114 | self.window_size = window_size 115 | self.heads_share = heads_share 116 | self.block_length = block_length 117 | self.proximal_bias = proximal_bias 118 | self.p_dropout = p_dropout 119 | self.attn = None 120 | 121 | self.k_channels = channels // n_heads 122 | self.conv_q = nn.Conv1d(channels, channels, 1) 123 | self.conv_k = nn.Conv1d(channels, channels, 1) 124 | self.conv_v = nn.Conv1d(channels, channels, 1) 125 | if window_size is not None: 126 | n_heads_rel = 1 if heads_share else n_heads 127 | rel_stddev = self.k_channels**-0.5 128 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 129 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 130 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 131 | self.drop = nn.Dropout(p_dropout) 132 | 133 | nn.init.xavier_uniform_(self.conv_q.weight) 134 | nn.init.xavier_uniform_(self.conv_k.weight) 135 | if proximal_init: 136 | self.conv_k.weight.data.copy_(self.conv_q.weight.data) 137 | self.conv_k.bias.data.copy_(self.conv_q.bias.data) 138 | nn.init.xavier_uniform_(self.conv_v.weight) 139 | 140 | def forward(self, x, c, attn_mask=None): 141 | q = self.conv_q(x) 142 | k = self.conv_k(c) 143 | v = self.conv_v(c) 144 | 145 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 146 | 147 | x = self.conv_o(x) 148 | return x 149 | 150 | def attention(self, query, key, value, mask=None): 151 | # reshape [b, d, t] -> [b, n_h, t, d_k] 152 | b, d, t_s, t_t = (*key.size(), query.size(2)) 153 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 154 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 155 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 156 | 157 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) 158 | if self.window_size is not None: 159 | assert t_s == t_t, "Relative attention is only available for self-attention." 
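# Note on the windowed relative-position attention used in this branch (a reading of the
# code that follows, not an addition to it): a learned embedding for each relative offset
# in [-window_size, window_size] is sliced to the current length, matched against the
# queries to give [b, h, l, 2*l-1] relative logits, shifted into absolute [b, h, l, l]
# indexing by _relative_position_to_absolute_position, scaled by 1/sqrt(k_channels),
# and added to the content-based scores before the softmax.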
160 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 161 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) 162 | rel_logits = self._relative_position_to_absolute_position(rel_logits) 163 | scores_local = rel_logits / math.sqrt(self.k_channels) 164 | scores = scores + scores_local 165 | if self.proximal_bias: 166 | assert t_s == t_t, "Proximal bias is only available for self-attention." 167 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 168 | if mask is not None: 169 | scores = scores.masked_fill(mask == 0, -1e4) 170 | if self.block_length is not None: 171 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 172 | scores = scores * block_mask + -1e4*(1 - block_mask) 173 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 174 | p_attn = self.drop(p_attn) 175 | output = torch.matmul(p_attn, value) 176 | if self.window_size is not None: 177 | relative_weights = self._absolute_position_to_relative_position(p_attn) 178 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 179 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 180 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 181 | return output, p_attn 182 | 183 | def _matmul_with_relative_values(self, x, y): 184 | """ 185 | x: [b, h, l, m] 186 | y: [h or 1, m, d] 187 | ret: [b, h, l, d] 188 | """ 189 | ret = torch.matmul(x, y.unsqueeze(0)) 190 | return ret 191 | 192 | def _matmul_with_relative_keys(self, x, y): 193 | """ 194 | x: [b, h, l, d] 195 | y: [h or 1, m, d] 196 | ret: [b, h, l, m] 197 | """ 198 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 199 | return ret 200 | 201 | def _get_relative_embeddings(self, relative_embeddings, length): 202 | max_relative_position = 2 * self.window_size + 1 203 | # Pad first before slice to avoid using cond ops. 204 | pad_length = max(length - (self.window_size + 1), 0) 205 | slice_start_position = max((self.window_size + 1) - length, 0) 206 | slice_end_position = slice_start_position + 2 * length - 1 207 | if pad_length > 0: 208 | padded_relative_embeddings = F.pad( 209 | relative_embeddings, 210 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 211 | else: 212 | padded_relative_embeddings = relative_embeddings 213 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] 214 | return used_relative_embeddings 215 | 216 | def _relative_position_to_absolute_position(self, x): 217 | """ 218 | x: [b, h, l, 2*l-1] 219 | ret: [b, h, l, l] 220 | """ 221 | batch, heads, length, _ = x.size() 222 | # Concat columns of pad to shift from relative to absolute indexing. 223 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 224 | 225 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 226 | x_flat = x.view([batch, heads, length * 2 * length]) 227 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) 228 | 229 | # Reshape and slice out the padded elements. 
230 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 231 | return x_final 232 | 233 | def _absolute_position_to_relative_position(self, x): 234 | """ 235 | x: [b, h, l, l] 236 | ret: [b, h, l, 2*l-1] 237 | """ 238 | batch, heads, length, _ = x.size() 239 | # padd along column 240 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 241 | x_flat = x.view([batch, heads, length**2 + length*(length -1)]) 242 | # add 0's in the beginning that will skew the elements after reshape 243 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 244 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 245 | return x_final 246 | 247 | def _attention_bias_proximal(self, length): 248 | """Bias for self-attention to encourage attention to close positions. 249 | Args: 250 | length: an integer scalar. 251 | Returns: 252 | a Tensor with shape [1, 1, length, length] 253 | """ 254 | r = torch.arange(length, dtype=torch.float32) 255 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 256 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 257 | 258 | 259 | class FFN(nn.Module): 260 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None): 261 | super().__init__() 262 | self.in_channels = in_channels 263 | self.out_channels = out_channels 264 | self.filter_channels = filter_channels 265 | self.kernel_size = kernel_size 266 | self.p_dropout = p_dropout 267 | self.activation = activation 268 | 269 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 270 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size//2) 271 | self.drop = nn.Dropout(p_dropout) 272 | 273 | def forward(self, x, x_mask): 274 | x = self.conv_1(x * x_mask) 275 | if self.activation == "gelu": 276 | x = x * torch.sigmoid(1.702 * x) 277 | else: 278 | x = torch.relu(x) 279 | x = self.drop(x) 280 | x = self.conv_2(x * x_mask) 281 | return x * x_mask 282 | 283 | -------------------------------------------------------------------------------- /filelists/ljs_audio_text_val_filelist.txt: -------------------------------------------------------------------------------- 1 | DUMMY/LJ022-0023.wav|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read. 2 | DUMMY/LJ043-0030.wav|If somebody did that to me, a lousy trick like that, to take my wife away, and all the furniture, I would be mad as hell, too. 3 | DUMMY/LJ005-0201.wav|as is shown by the report of the Commissioners to inquire into the state of the municipal corporations in eighteen thirty-five. 4 | DUMMY/LJ001-0110.wav|Even the Caslon type when enlarged shows great shortcomings in this respect: 5 | DUMMY/LJ003-0345.wav|All the committee could do in this respect was to throw the responsibility on others. 6 | DUMMY/LJ007-0154.wav|These pungent and well-grounded strictures applied with still greater force to the unconvicted prisoner, the man who came to the prison innocent, and still uncontaminated, 7 | DUMMY/LJ018-0098.wav|and recognized as one of the frequenters of the bogus law-stationers. His arrest led to that of others. 8 | DUMMY/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. 
He denied having any involvement with Soviet intelligence agencies 9 | DUMMY/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery. 10 | DUMMY/LJ048-0194.wav|during the morning of November twenty-two prior to the motorcade. 11 | DUMMY/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President. 12 | DUMMY/LJ004-0152.wav|although at Mr. Buxton's visit a new jail was in process of erection, the first step towards reform since Howard's visitation in seventeen seventy-four. 13 | DUMMY/LJ008-0278.wav|or theirs might be one of many, and it might be considered necessary to "make an example." 14 | DUMMY/LJ043-0002.wav|The Warren Commission Report. By The President's Commission on the Assassination of President Kennedy. Chapter seven. Lee Harvey Oswald: 15 | DUMMY/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here. 16 | DUMMY/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work. 17 | DUMMY/LJ050-0168.wav|with the particular purposes of the agency involved. The Commission recognizes that this is a controversial area 18 | DUMMY/LJ039-0223.wav|Oswald's Marine training in marksmanship, his other rifle experience and his established familiarity with this particular weapon 19 | DUMMY/LJ029-0032.wav|According to O'Donnell, quote, we had a motorcade wherever we went, end quote. 20 | DUMMY/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound, 21 | DUMMY/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window. 22 | DUMMY/LJ026-0068.wav|Energy enters the plant, to a small extent, 23 | DUMMY/LJ039-0075.wav|once you know that you must put the crosshairs on the target and that is all that is necessary. 24 | DUMMY/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized 25 | DUMMY/LJ005-0014.wav|Speaking on a debate on prison matters, he declared that 26 | DUMMY/LJ012-0161.wav|he was reported to have fallen away to a shadow. 27 | DUMMY/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to 28 | DUMMY/LJ019-0257.wav|Here the tread-wheel was in use, there cellular cranks, or hard-labor machines. 29 | DUMMY/LJ028-0008.wav|you tap gently with your heel upon the shoulder of the dromedary to urge her on. 30 | DUMMY/LJ024-0083.wav|This plan of mine is no attack on the Court; 31 | DUMMY/LJ042-0129.wav|No night clubs or bowling alleys, no places of recreation except the trade union dances. I have had enough. 32 | DUMMY/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup. 33 | DUMMY/LJ046-0058.wav|During his Presidency, Franklin D. Roosevelt made almost four hundred journeys and traveled more than three hundred fifty thousand miles. 34 | DUMMY/LJ014-0076.wav|He was seen afterwards smoking and talking with his hosts in their back parlor, and never seen again alive. 35 | DUMMY/LJ002-0043.wav|long narrow rooms -- one thirty-six feet, six twenty-three feet, and the eighth eighteen, 36 | DUMMY/LJ009-0076.wav|We come to the sermon. 
37 | DUMMY/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution. 38 | DUMMY/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes. 39 | DUMMY/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art. 40 | DUMMY/LJ042-0096.wav|(old exchange rate) in addition to his factory salary of approximately equal amount 41 | DUMMY/LJ049-0050.wav|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy. 42 | DUMMY/LJ019-0186.wav|seeing that since the establishment of the Central Criminal Court, Newgate received prisoners for trial from several counties, 43 | DUMMY/LJ028-0307.wav|then let twenty days pass, and at the end of that time station near the Chaldasan gates a body of four thousand. 44 | DUMMY/LJ012-0235.wav|While they were in a state of insensibility the murder was committed. 45 | DUMMY/LJ034-0053.wav|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald. 46 | DUMMY/LJ014-0030.wav|These were damnatory facts which well supported the prosecution. 47 | DUMMY/LJ015-0203.wav|but were the precautions too minute, the vigilance too close to be eluded or overcome? 48 | DUMMY/LJ028-0093.wav|but his scribe wrote it in the manner customary for the scribes of those days to write of their royal masters. 49 | DUMMY/LJ002-0018.wav|The inadequacy of the jail was noticed and reported upon again and again by the grand juries of the city of London, 50 | DUMMY/LJ028-0275.wav|At last, in the twentieth month, 51 | DUMMY/LJ012-0042.wav|which he kept concealed in a hiding-place with a trap-door just under his bed. 52 | DUMMY/LJ011-0096.wav|He married a lady also belonging to the Society of Friends, who brought him a large fortune, which, and his own money, he put into a city firm, 53 | DUMMY/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County, 54 | DUMMY/LJ016-0318.wav|Other officials, great lawyers, governors of prisons, and chaplains supported this view. 55 | DUMMY/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning. 56 | DUMMY/LJ027-0141.wav|is closely reproduced in the life-history of existing deer. Or, in other words, 57 | DUMMY/LJ028-0335.wav|accordingly they committed to him the command of their whole army, and put the keys of their city into his hands. 58 | DUMMY/LJ031-0202.wav|Mrs. Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy. 59 | DUMMY/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace, 60 | DUMMY/LJ016-0288.wav|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells. 61 | DUMMY/LJ028-0081.wav|Years later, when the archaeologists could readily distinguish the false from the true, 62 | DUMMY/LJ018-0081.wav|his defense being that he had intended to commit suicide, but that, on the appearance of this officer who had wronged him, 63 | DUMMY/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits 64 | DUMMY/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail. 
65 | DUMMY/LJ005-0079.wav|and improve the morals of the prisoners, and shall insure the proper measure of punishment to convicted offenders. 66 | DUMMY/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal. 67 | DUMMY/LJ036-0174.wav|This is the approximate time he entered the roominghouse, according to Earlene Roberts, the housekeeper there. 68 | DUMMY/LJ046-0146.wav|The criteria in effect prior to November twenty-two, nineteen sixty-three, for determining whether to accept material for the PRS general files 69 | DUMMY/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator. 70 | DUMMY/LJ017-0070.wav|but his sporting operations did not prosper, and he became a needy man, always driven to desperate straits for cash. 71 | DUMMY/LJ014-0020.wav|He was soon afterwards arrested on suspicion, and a search of his lodgings brought to light several garments saturated with blood; 72 | DUMMY/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely. 73 | DUMMY/LJ045-0230.wav|when he was finally apprehended in the Texas Theatre. Although it is not fully corroborated by others who were present, 74 | DUMMY/LJ035-0129.wav|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him. 75 | DUMMY/LJ008-0307.wav|afterwards express a wish to murder the Recorder for having kept them so long in suspense. 76 | DUMMY/LJ008-0294.wav|nearly indefinitely deferred. 77 | DUMMY/LJ047-0148.wav|On October twenty-five, 78 | DUMMY/LJ008-0111.wav|They entered a "stone cold room," and were presently joined by the prisoner. 79 | DUMMY/LJ034-0042.wav|that he could only testify with certainty that the print was less than three days old. 80 | DUMMY/LJ037-0234.wav|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male, 81 | DUMMY/LJ040-0002.wav|Chapter seven. Lee Harvey Oswald: Background and Possible Motives, Part one. 82 | DUMMY/LJ045-0140.wav|The arguments he used to justify his use of the alias suggest that Oswald may have come to think that the whole world was becoming involved 83 | DUMMY/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands. 84 | DUMMY/LJ012-0250.wav|On the seventh July, eighteen thirty-seven, 85 | DUMMY/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job. 86 | DUMMY/LJ016-0138.wav|at a distance from the prison. 87 | DUMMY/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology. 88 | DUMMY/LJ031-0134.wav|On one occasion Mrs. Johnson, accompanied by two Secret Service agents, left the room to see Mrs. Kennedy and Mrs. Connally. 89 | DUMMY/LJ019-0273.wav|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline. 90 | DUMMY/LJ014-0110.wav|At the first the boxes were impounded, opened, and found to contain many of O'Connor's effects. 91 | DUMMY/LJ034-0160.wav|on Brennan's subsequent certain identification of Lee Harvey Oswald as the man he saw fire the rifle. 92 | DUMMY/LJ038-0199.wav|eleven. If I am alive and taken prisoner, 93 | DUMMY/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came. 
94 | DUMMY/LJ033-0047.wav|I noticed when I went out that the light was on, end quote, 95 | DUMMY/LJ040-0027.wav|He was never satisfied with anything. 96 | DUMMY/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly. 97 | DUMMY/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity. 98 | DUMMY/LJ008-0258.wav|Let me retrace my steps, and speak more in detail of the treatment of the condemned in those bloodthirsty and brutally indifferent days, 99 | DUMMY/LJ029-0022.wav|The original plan called for the President to spend only one day in the State, making whirlwind visits to Dallas, Fort Worth, San Antonio, and Houston. 100 | DUMMY/LJ004-0045.wav|Mr. Sturges Bourne, Sir James Mackintosh, Sir James Scarlett, and William Wilberforce. 101 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import modules 9 | import commons 10 | import attentions 11 | import monotonic_align 12 | import unet 13 | 14 | class DurationPredictor(nn.Module): 15 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): 16 | super().__init__() 17 | 18 | self.in_channels = in_channels 19 | self.filter_channels = filter_channels 20 | self.kernel_size = kernel_size 21 | self.p_dropout = p_dropout 22 | 23 | self.drop = nn.Dropout(p_dropout) 24 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 25 | self.norm_1 = attentions.LayerNorm(filter_channels) 26 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) 27 | self.norm_2 = attentions.LayerNorm(filter_channels) 28 | self.proj = nn.Conv1d(filter_channels, 1, 1) 29 | 30 | def forward(self, x, x_mask): 31 | x = self.conv_1(x * x_mask) 32 | x = torch.relu(x) 33 | x = self.norm_1(x) 34 | x = self.drop(x) 35 | x = self.conv_2(x * x_mask) 36 | x = torch.relu(x) 37 | x = self.norm_2(x) 38 | x = self.drop(x) 39 | x = self.proj(x * x_mask) 40 | return x * x_mask 41 | 42 | 43 | class TextEncoder(nn.Module): 44 | def __init__(self, 45 | n_vocab, 46 | out_channels, 47 | hidden_channels, 48 | filter_channels, 49 | filter_channels_dp, 50 | n_heads, 51 | n_layers, 52 | kernel_size, 53 | p_dropout, 54 | window_size=None, 55 | block_length=None, 56 | mean_only=False, 57 | prenet=False, 58 | gin_channels=0): 59 | 60 | super().__init__() 61 | 62 | self.n_vocab = n_vocab 63 | self.out_channels = out_channels 64 | self.hidden_channels = hidden_channels 65 | self.filter_channels = filter_channels 66 | self.filter_channels_dp = filter_channels_dp 67 | self.n_heads = n_heads 68 | self.n_layers = n_layers 69 | self.kernel_size = kernel_size 70 | self.p_dropout = p_dropout 71 | self.window_size = window_size 72 | self.block_length = block_length 73 | self.mean_only = mean_only 74 | self.prenet = prenet 75 | self.gin_channels = gin_channels 76 | 77 | self.emb = nn.Embedding(n_vocab, hidden_channels) 78 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 79 | 80 | if prenet: 81 | self.pre = modules.ConvReluNorm(hidden_channels, hidden_channels, hidden_channels, kernel_size=5, n_layers=3, p_dropout=0.5) 82 | self.encoder = attentions.Encoder( 83 | hidden_channels, 84 | filter_channels, 85 | 
n_heads, 86 | n_layers, 87 | kernel_size, 88 | p_dropout, 89 | window_size=window_size, 90 | block_length=block_length, 91 | ) 92 | 93 | self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1) 94 | if not mean_only: 95 | self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) 96 | self.proj_w = DurationPredictor(hidden_channels + gin_channels, filter_channels_dp, kernel_size, p_dropout) 97 | 98 | def forward(self, x, x_lengths, g=None): 99 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 100 | x = torch.transpose(x, 1, -1) # [b, h, t] 101 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 102 | 103 | if self.prenet: 104 | x = self.pre(x, x_mask) 105 | x = self.encoder(x, x_mask) 106 | 107 | if g is not None: 108 | g_exp = g.expand(-1, -1, x.size(-1)) 109 | x_dp = torch.cat([torch.detach(x), g_exp], 1) 110 | else: 111 | x_dp = torch.detach(x) 112 | 113 | x_m = self.proj_m(x) * x_mask 114 | if not self.mean_only: 115 | x_logs = self.proj_s(x) * x_mask 116 | else: 117 | x_logs = torch.zeros_like(x_m) 118 | 119 | logw = self.proj_w(x_dp, x_mask) 120 | return x_m, x_logs, logw, x_mask 121 | 122 | 123 | class DiffusionDecoder(nn.Module): 124 | def __init__(self, 125 | unet_channels=64, 126 | unet_in_channels=2, 127 | unet_out_channels=1, 128 | dim_mults=(1, 2, 4), 129 | groups=8, 130 | with_time_emb=True, 131 | beta_0=0.05, 132 | beta_1=20, 133 | N=1000, 134 | T=1): 135 | 136 | super().__init__() 137 | 138 | self.beta_0 = beta_0 139 | self.beta_1 = beta_1 140 | self.N = N 141 | self.T = T 142 | self.delta_t = T*1.0 / N 143 | self.discrete_betas = torch.linspace(beta_0, beta_1, N) 144 | self.unet = unet.Unet(dim=unet_channels, out_dim=unet_out_channels, dim_mults=dim_mults, groups=groups, channels=unet_in_channels, with_time_emb=with_time_emb) 145 | 146 | def marginal_prob(self, mu, x, t): 147 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 148 | mean = torch.exp(log_mean_coeff[:, None, None]) * x + (1-torch.exp(log_mean_coeff[:, None, None]) ) * mu 149 | std = torch.sqrt(1. - torch.exp(2. 
* log_mean_coeff)) 150 | return mean, std 151 | 152 | def cal_loss(self, x, mu, t, z, std, g=None): 153 | time_steps = t * (self.N - 1) 154 | if g: 155 | x = torch.stack([x, mu, g], 1) 156 | else: 157 | x = torch.stack([x, mu], 1) 158 | grad = self.unet(x, time_steps) 159 | loss = torch.square(grad + z / std[:, None, None]) * torch.square(std[:, None, None]) 160 | return loss 161 | 162 | def forward(self, mu, y=None, g=None, gen=False): 163 | if not gen: 164 | t = torch.FloatTensor(y.shape[0]).uniform_(0, self.T-self.delta_t).to(y.device)+self.delta_t # sample a random t 165 | mean, std = self.marginal_prob(mu, y, t) 166 | z = torch.randn_like(y) 167 | x = mean + std[:, None, None] * z 168 | loss = self.cal_loss(x, mu, t, z, std, g) 169 | return loss 170 | else: 171 | with torch.no_grad(): 172 | y_T = torch.randn_like(mu) + mu 173 | y_t_plus_one = y_T 174 | y_t = None 175 | for n in tqdm(range(self.N - 1, 0, -1)): 176 | t = torch.FloatTensor(1).fill_(n).to(mu.device) 177 | if g: 178 | x = torch.stack([y_t_plus_one, mu, g], 1) 179 | else: 180 | x = torch.stack([y_t_plus_one, mu], 1) 181 | grad = self.unet(x, t) 182 | y_t = y_t_plus_one-0.5*self.delta_t*self.discrete_betas[n]*(mu-y_t_plus_one-grad) 183 | y_t_plus_one = y_t 184 | 185 | return y_t 186 | 187 | 188 | class DiffusionGenerator(nn.Module): 189 | def __init__(self, 190 | n_vocab, 191 | hidden_channels, 192 | filter_channels, 193 | filter_channels_dp, 194 | enc_out_channels, 195 | kernel_size=3, 196 | n_heads=2, 197 | n_layers_enc=6, 198 | p_dropout=0., 199 | n_speakers=0, 200 | gin_channels=0, 201 | window_size=None, 202 | block_length=None, 203 | mean_only=False, 204 | hidden_channels_enc=None, 205 | hidden_channels_dec=None, 206 | prenet=False, 207 | dec_unet_channels=64, 208 | dec_dim_mults=(1, 2, 4), 209 | dec_groups=8, 210 | dec_unet_in_channels=2, 211 | dec_unet_out_channels=1, 212 | dec_with_time_emb=True, 213 | beta_0=0.05, 214 | beta_1=20, 215 | N=1000, 216 | T=1, 217 | **kwargs): 218 | 219 | super().__init__() 220 | self.n_vocab = n_vocab 221 | self.hidden_channels = hidden_channels 222 | self.filter_channels = filter_channels 223 | self.filter_channels_dp = filter_channels_dp 224 | self.enc_out_channels = enc_out_channels 225 | self.dec_in_channels = enc_out_channels 226 | self.kernel_size = kernel_size 227 | self.n_heads = n_heads 228 | self.n_layers_enc = n_layers_enc 229 | self.p_dropout = p_dropout 230 | self.n_speakers = n_speakers 231 | self.gin_channels = enc_out_channels 232 | self.window_size = window_size 233 | self.block_length = block_length 234 | self.mean_only = mean_only 235 | self.hidden_channels_enc = hidden_channels_enc 236 | self.hidden_channels_dec = hidden_channels_dec 237 | self.prenet = prenet 238 | self.dec_unet_channels = dec_unet_channels 239 | self.dec_unet_in_channels = dec_unet_in_channels if self.n_speakers < 1 else dec_unet_in_channels+1 240 | self.dec_unet_out_channels = dec_unet_out_channels 241 | self.dec_dim_mults = dec_dim_mults 242 | self.dec_groups = dec_groups 243 | self.dec_with_time_emb = dec_with_time_emb 244 | self.beta_0 = beta_0 245 | self.beta_1 = beta_1 246 | self.N = N 247 | self.T = T 248 | 249 | self.encoder = TextEncoder( 250 | n_vocab, 251 | enc_out_channels, 252 | hidden_channels_enc or hidden_channels, 253 | filter_channels, 254 | filter_channels_dp, 255 | n_heads, 256 | n_layers_enc, 257 | kernel_size, 258 | p_dropout, 259 | window_size=window_size, 260 | block_length=block_length, 261 | mean_only=mean_only, 262 | prenet=prenet, 263 | gin_channels=gin_channels) 264 | 
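# The decoder constructed below is the score-based module defined above. Its
# marginal_prob gives the closed-form perturbation kernel
#     mean(t) = exp(lam(t)) * y + (1 - exp(lam(t))) * mu
#     std(t)  = sqrt(1 - exp(2 * lam(t)))
#     lam(t)  = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0
# so noised mels drift from the ground-truth y toward the encoder output mu as t grows,
# and cal_loss trains the U-Net to match the kernel's score -z / std; with gen=True,
# sampling starts from N(mu, I) and applies the discretized reverse updates over N steps.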
265 | self.decoder = DiffusionDecoder( 266 | unet_channels=self.dec_unet_channels, 267 | unet_in_channels=self.dec_unet_in_channels, 268 | unet_out_channels=self.dec_unet_out_channels, 269 | dim_mults=self.dec_dim_mults, 270 | groups=self.dec_groups, 271 | with_time_emb=self.dec_with_time_emb, 272 | beta_0=self.beta_0, 273 | beta_1=self.beta_1, 274 | N=self.N, 275 | T=self.T) 276 | 277 | if n_speakers > 1: 278 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 279 | nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) 280 | 281 | def forward(self, x, x_lengths, y=None, y_lengths=None, g=None, gen=False, noise_scale=1., length_scale=1.): 282 | if g is not None: 283 | g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] 284 | x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths, g=g) 285 | 286 | if gen: 287 | w = torch.exp(logw) * x_mask * length_scale 288 | w_ceil = torch.ceil(w) 289 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 290 | y_max_length = None 291 | else: 292 | y_max_length = y.size(2) 293 | #y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length) 294 | z_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) 295 | attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2) 296 | 297 | if gen: 298 | attn = commons.generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) 299 | z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 300 | z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2) 301 | logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask 302 | 303 | y = self.decoder(z_m, gen=True) 304 | return (y, z_m, z_logs, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_) 305 | else: 306 | with torch.no_grad(): 307 | x_s_sq_r = torch.exp(-2 * x_logs) 308 | logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1) # [b, t, 1] 309 | logp2 = torch.matmul(x_s_sq_r.transpose(1,2), -0.5 * (y ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] 310 | logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1,2), y) # [b, t, d] x [b, d, t'] = [b, t, t'] 311 | logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(-1) # [b, t, 1] 312 | logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] 313 | 314 | attn = monotonic_align.maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() 315 | z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 316 | z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 317 | logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask 318 | grad_loss = self.decoder(mu=z_m, y=y, g=g, gen=False).mean() 319 | return grad_loss, (z_m, z_logs, z_mask), (attn, logw, logw_) 320 | -------------------------------------------------------------------------------- /waveglow/glow.py: -------------------------------------------------------------------------------- 1 | # ***************************************************************************** 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 
8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of the NVIDIA CORPORATION nor the 12 | # names of its contributors may be used to endorse or promote products 13 | # derived from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | # 26 | # ***************************************************************************** 27 | import copy 28 | import torch 29 | from torch.autograd import Variable 30 | import torch.nn.functional as F 31 | 32 | 33 | @torch.jit.script 34 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 35 | n_channels_int = n_channels[0] 36 | in_act = input_a+input_b 37 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 38 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 39 | acts = t_act * s_act 40 | return acts 41 | 42 | 43 | class WaveGlowLoss(torch.nn.Module): 44 | def __init__(self, sigma=1.0): 45 | super(WaveGlowLoss, self).__init__() 46 | self.sigma = sigma 47 | 48 | def forward(self, model_output): 49 | z, log_s_list, log_det_W_list = model_output 50 | for i, log_s in enumerate(log_s_list): 51 | if i == 0: 52 | log_s_total = torch.sum(log_s) 53 | log_det_W_total = log_det_W_list[i] 54 | else: 55 | log_s_total = log_s_total + torch.sum(log_s) 56 | log_det_W_total += log_det_W_list[i] 57 | 58 | loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total 59 | return loss/(z.size(0)*z.size(1)*z.size(2)) 60 | 61 | 62 | class Invertible1x1Conv(torch.nn.Module): 63 | """ 64 | The layer outputs both the convolution, and the log determinant 65 | of its weight matrix. 
If reverse=True it does convolution with 66 | inverse 67 | """ 68 | def __init__(self, c): 69 | super(Invertible1x1Conv, self).__init__() 70 | self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, 71 | bias=False) 72 | 73 | # Sample a random orthonormal matrix to initialize weights 74 | W = torch.qr(torch.FloatTensor(c, c).normal_())[0] 75 | 76 | # Ensure determinant is 1.0 not -1.0 77 | if torch.det(W) < 0: 78 | W[:,0] = -1*W[:,0] 79 | W = W.view(c, c, 1) 80 | self.conv.weight.data = W 81 | 82 | def forward(self, z, reverse=False): 83 | # shape 84 | batch_size, group_size, n_of_groups = z.size() 85 | 86 | W = self.conv.weight.squeeze() 87 | 88 | if reverse: 89 | if not hasattr(self, 'W_inverse'): 90 | # Reverse computation 91 | W_inverse = W.float().inverse() 92 | W_inverse = Variable(W_inverse[..., None]) 93 | if z.type() == 'torch.cuda.HalfTensor': 94 | W_inverse = W_inverse.half() 95 | self.W_inverse = W_inverse 96 | z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) 97 | return z 98 | else: 99 | # Forward computation 100 | log_det_W = batch_size * n_of_groups * torch.logdet(W) 101 | z = self.conv(z) 102 | return z, log_det_W 103 | 104 | 105 | class WN(torch.nn.Module): 106 | """ 107 | This is the WaveNet like layer for the affine coupling. The primary difference 108 | from WaveNet is the convolutions need not be causal. There is also no dilation 109 | size reset. The dilation only doubles on each layer 110 | """ 111 | def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, 112 | kernel_size): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | assert(n_channels % 2 == 0) 116 | self.n_layers = n_layers 117 | self.n_channels = n_channels 118 | self.in_layers = torch.nn.ModuleList() 119 | self.res_skip_layers = torch.nn.ModuleList() 120 | 121 | start = torch.nn.Conv1d(n_in_channels, n_channels, 1) 122 | start = torch.nn.utils.weight_norm(start, name='weight') 123 | self.start = start 124 | 125 | # Initializing last layer to 0 makes the affine coupling layers 126 | # do nothing at first. 
This helps with training stability 127 | end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) 128 | end.weight.data.zero_() 129 | end.bias.data.zero_() 130 | self.end = end 131 | 132 | cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1) 133 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 134 | 135 | for i in range(n_layers): 136 | dilation = 2 ** i 137 | padding = int((kernel_size*dilation - dilation)/2) 138 | in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, 139 | dilation=dilation, padding=padding) 140 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 141 | self.in_layers.append(in_layer) 142 | 143 | 144 | # last one is not necessary 145 | if i < n_layers - 1: 146 | res_skip_channels = 2*n_channels 147 | else: 148 | res_skip_channels = n_channels 149 | res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) 150 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 151 | self.res_skip_layers.append(res_skip_layer) 152 | 153 | def forward(self, forward_input): 154 | audio, spect = forward_input 155 | audio = self.start(audio) 156 | output = torch.zeros_like(audio) 157 | n_channels_tensor = torch.IntTensor([self.n_channels]) 158 | 159 | spect = self.cond_layer(spect) 160 | 161 | for i in range(self.n_layers): 162 | spect_offset = i*2*self.n_channels 163 | acts = fused_add_tanh_sigmoid_multiply( 164 | self.in_layers[i](audio), 165 | spect[:,spect_offset:spect_offset+2*self.n_channels,:], 166 | n_channels_tensor) 167 | 168 | res_skip_acts = self.res_skip_layers[i](acts) 169 | if i < self.n_layers - 1: 170 | audio = audio + res_skip_acts[:,:self.n_channels,:] 171 | output = output + res_skip_acts[:,self.n_channels:,:] 172 | else: 173 | output = output + res_skip_acts 174 | 175 | return self.end(output) 176 | 177 | 178 | class WaveGlow(torch.nn.Module): 179 | def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, 180 | n_early_size, WN_config): 181 | super(WaveGlow, self).__init__() 182 | 183 | self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, 184 | n_mel_channels, 185 | 1024, stride=256) 186 | assert(n_group % 2 == 0) 187 | self.n_flows = n_flows 188 | self.n_group = n_group 189 | self.n_early_every = n_early_every 190 | self.n_early_size = n_early_size 191 | self.WN = torch.nn.ModuleList() 192 | self.convinv = torch.nn.ModuleList() 193 | 194 | n_half = int(n_group/2) 195 | 196 | # Set up layers with the right sizes based on how many dimensions 197 | # have been output already 198 | n_remaining_channels = n_group 199 | for k in range(n_flows): 200 | if k % self.n_early_every == 0 and k > 0: 201 | n_half = n_half - int(self.n_early_size/2) 202 | n_remaining_channels = n_remaining_channels - self.n_early_size 203 | self.convinv.append(Invertible1x1Conv(n_remaining_channels)) 204 | self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) 205 | self.n_remaining_channels = n_remaining_channels # Useful during inference 206 | 207 | def forward(self, forward_input): 208 | """ 209 | forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames 210 | forward_input[1] = audio: batch x time 211 | """ 212 | spect, audio = forward_input 213 | 214 | # Upsample spectrogram to size of audio 215 | spect = self.upsample(spect) 216 | assert(spect.size(2) >= audio.size(1)) 217 | if spect.size(2) > audio.size(1): 218 | spect = spect[:, :, :audio.size(1)] 219 | 220 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 221 | spect = 
spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 222 | 223 | audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1) 224 | output_audio = [] 225 | log_s_list = [] 226 | log_det_W_list = [] 227 | 228 | for k in range(self.n_flows): 229 | if k % self.n_early_every == 0 and k > 0: 230 | output_audio.append(audio[:,:self.n_early_size,:]) 231 | audio = audio[:,self.n_early_size:,:] 232 | 233 | audio, log_det_W = self.convinv[k](audio) 234 | log_det_W_list.append(log_det_W) 235 | 236 | n_half = int(audio.size(1)/2) 237 | audio_0 = audio[:,:n_half,:] 238 | audio_1 = audio[:,n_half:,:] 239 | 240 | output = self.WN[k]((audio_0, spect)) 241 | log_s = output[:, n_half:, :] 242 | b = output[:, :n_half, :] 243 | audio_1 = torch.exp(log_s)*audio_1 + b 244 | log_s_list.append(log_s) 245 | 246 | audio = torch.cat([audio_0, audio_1],1) 247 | 248 | output_audio.append(audio) 249 | return torch.cat(output_audio,1), log_s_list, log_det_W_list 250 | 251 | def infer(self, spect, sigma=1.0): 252 | spect = self.upsample(spect) 253 | # trim conv artifacts. maybe pad spec to kernel multiple 254 | time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] 255 | spect = spect[:, :, :-time_cutoff] 256 | 257 | spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) 258 | spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) 259 | 260 | if spect.type() == 'torch.cuda.HalfTensor': 261 | audio = torch.cuda.HalfTensor(spect.size(0), 262 | self.n_remaining_channels, 263 | spect.size(2)).normal_() 264 | else: 265 | audio = torch.cuda.FloatTensor(spect.size(0), 266 | self.n_remaining_channels, 267 | spect.size(2)).normal_() 268 | 269 | audio = torch.autograd.Variable(sigma*audio) 270 | 271 | for k in reversed(range(self.n_flows)): 272 | n_half = int(audio.size(1)/2) 273 | audio_0 = audio[:,:n_half,:] 274 | audio_1 = audio[:,n_half:,:] 275 | 276 | output = self.WN[k]((audio_0, spect)) 277 | 278 | s = output[:, n_half:, :] 279 | b = output[:, :n_half, :] 280 | audio_1 = (audio_1 - b)/torch.exp(s) 281 | audio = torch.cat([audio_0, audio_1],1) 282 | 283 | audio = self.convinv[k](audio, reverse=True) 284 | 285 | if k % self.n_early_every == 0 and k > 0: 286 | if spect.type() == 'torch.cuda.HalfTensor': 287 | z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 288 | else: 289 | z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() 290 | audio = torch.cat((sigma*z, audio),1) 291 | 292 | audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data 293 | return audio 294 | 295 | @staticmethod 296 | def remove_weightnorm(model): 297 | waveglow = model 298 | for WN in waveglow.WN: 299 | WN.start = torch.nn.utils.remove_weight_norm(WN.start) 300 | WN.in_layers = remove(WN.in_layers) 301 | WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer) 302 | WN.res_skip_layers = remove(WN.res_skip_layers) 303 | return waveglow 304 | 305 | 306 | def remove(conv_list): 307 | new_conv_list = torch.nn.ModuleList() 308 | for old_conv in conv_list: 309 | old_conv = torch.nn.utils.remove_weight_norm(old_conv) 310 | new_conv_list.append(old_conv) 311 | return new_conv_list 312 | -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | import torch 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | from inspect 
import isfunction 7 | from functools import partial 8 | 9 | 10 | import numpy as np 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | 14 | 15 | # helpers functions 16 | 17 | def exists(x): 18 | return x is not None 19 | 20 | def default(val, d): 21 | if exists(val): 22 | return val 23 | return d() if isfunction(d) else d 24 | 25 | def cycle(dl): 26 | while True: 27 | for data in dl: 28 | yield data 29 | 30 | def num_to_groups(num, divisor): 31 | groups = num // divisor 32 | remainder = num % divisor 33 | arr = [divisor] * groups 34 | if remainder > 0: 35 | arr.append(remainder) 36 | return arr 37 | 38 | def loss_backwards(fp16, loss, optimizer, **kwargs): 39 | if fp16: 40 | with amp.scale_loss(loss, optimizer) as scaled_loss: 41 | scaled_loss.backward(**kwargs) 42 | else: 43 | loss.backward(**kwargs) 44 | 45 | # small helper modules 46 | 47 | class EMA(): 48 | def __init__(self, beta): 49 | super().__init__() 50 | self.beta = beta 51 | 52 | def update_model_average(self, ma_model, current_model): 53 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 54 | old_weight, up_weight = ma_params.data, current_params.data 55 | ma_params.data = self.update_average(old_weight, up_weight) 56 | 57 | def update_average(self, old, new): 58 | if old is None: 59 | return new 60 | return old * self.beta + (1 - self.beta) * new 61 | 62 | class Residual(nn.Module): 63 | def __init__(self, fn): 64 | super().__init__() 65 | self.fn = fn 66 | 67 | def forward(self, x, *args, **kwargs): 68 | return self.fn(x, *args, **kwargs) + x 69 | 70 | class SinusoidalPosEmb(nn.Module): 71 | def __init__(self, dim): 72 | super().__init__() 73 | self.dim = dim 74 | 75 | def forward(self, x): 76 | device = x.device 77 | half_dim = self.dim // 2 78 | emb = math.log(10000) / (half_dim - 1) 79 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 80 | emb = x[:, None] * emb[None, :] 81 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 82 | return emb 83 | 84 | class Mish(nn.Module): 85 | def forward(self, x): 86 | return x * torch.tanh(F.softplus(x)) 87 | 88 | class Upsample(nn.Module): 89 | def __init__(self, dim): 90 | super().__init__() 91 | self.conv = nn.ConvTranspose2d(dim, dim, 4, 2, 1) 92 | 93 | def forward(self, x): 94 | return self.conv(x) 95 | 96 | class Downsample(nn.Module): 97 | def __init__(self, dim): 98 | super().__init__() 99 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 100 | 101 | def forward(self, x): 102 | return self.conv(x) 103 | 104 | class PreNorm(nn.Module): 105 | def __init__(self, dim, fn): 106 | super().__init__() 107 | self.fn = fn 108 | self.norm = nn.InstanceNorm2d(dim, affine = True) 109 | 110 | def forward(self, x): 111 | x = self.norm(x) 112 | return self.fn(x) 113 | 114 | # building block modules 115 | 116 | class Block(nn.Module): 117 | def __init__(self, dim, dim_out, groups = 8): 118 | super().__init__() 119 | self.block = nn.Sequential( 120 | nn.Conv2d(dim, dim_out, 3, padding=1), 121 | nn.GroupNorm(groups, dim_out), 122 | Mish() 123 | ) 124 | def forward(self, x): 125 | return self.block(x) 126 | 127 | class ResnetBlock(nn.Module): 128 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): 129 | super().__init__() 130 | self.mlp = nn.Sequential( 131 | Mish(), 132 | nn.Linear(time_emb_dim, dim_out) 133 | ) if exists(time_emb_dim) else None 134 | 135 | self.block1 = Block(dim, dim_out) 136 | self.block2 = Block(dim_out, dim_out) 137 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 
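# Descriptive note on ResnetBlock (commentary, not part of the original source):
# when time_emb_dim is given, self.mlp projects the (batch, time_emb_dim)
# timestep embedding to (batch, dim_out); forward() below broadcasts it over the
# spatial axes via [:, :, None, None] and adds it to the block1 output before
# block2. res_conv is a 1x1 convolution so the residual path matches shapes
# whenever dim != dim_out.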
138 | 139 | def forward(self, x, time_emb): 140 | h = self.block1(x) 141 | 142 | if exists(self.mlp): 143 | # print('hmmm') 144 | h += self.mlp(time_emb)[:, :, None, None] 145 | 146 | h = self.block2(h) 147 | return h + self.res_conv(x) 148 | 149 | class LinearAttention(nn.Module): 150 | def __init__(self, dim, heads = 4, dim_head = 32): 151 | super().__init__() 152 | self.heads = heads 153 | hidden_dim = dim_head * heads 154 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 155 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 156 | 157 | def forward(self, x): 158 | b, c, h, w = x.shape 159 | qkv = self.to_qkv(x) 160 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 161 | k = k.softmax(dim=-1) 162 | context = torch.einsum('bhdn,bhen->bhde', k, v) 163 | out = torch.einsum('bhde,bhdn->bhen', context, q) 164 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 165 | return self.to_out(out) 166 | 167 | # model 168 | 169 | class Unet(nn.Module): 170 | def __init__( 171 | self, 172 | dim, 173 | out_dim = None, 174 | dim_mults=(1, 2, 4, 8), 175 | groups = 8, 176 | channels = 2, 177 | with_time_emb = True 178 | ): 179 | super().__init__() 180 | self.channels = channels 181 | 182 | dims = [channels, *map(lambda m: dim * m, dim_mults)] 183 | in_out = list(zip(dims[:-1], dims[1:])) 184 | 185 | if with_time_emb: 186 | time_dim = dim 187 | self.time_mlp = nn.Sequential( 188 | SinusoidalPosEmb(dim), 189 | nn.Linear(dim, dim * 4), 190 | Mish(), 191 | nn.Linear(dim * 4, dim) 192 | ) 193 | else: 194 | time_dim = None 195 | self.time_mlp = None 196 | 197 | self.downs = nn.ModuleList([]) 198 | self.ups = nn.ModuleList([]) 199 | num_resolutions = len(in_out) 200 | 201 | for ind, (dim_in, dim_out) in enumerate(in_out): 202 | is_last = ind >= (num_resolutions - 1) 203 | 204 | self.downs.append(nn.ModuleList([ 205 | ResnetBlock(dim_in, dim_out, time_emb_dim = time_dim), 206 | ResnetBlock(dim_out, dim_out, time_emb_dim = time_dim), 207 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 208 | Downsample(dim_out) if not is_last else nn.Identity() 209 | ])) 210 | 211 | mid_dim = dims[-1] 212 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) 213 | self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) 214 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) 215 | 216 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 217 | is_last = ind >= (num_resolutions - 1) 218 | 219 | self.ups.append(nn.ModuleList([ 220 | ResnetBlock(dim_out * 2, dim_in, time_emb_dim = time_dim), 221 | ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), 222 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 223 | Upsample(dim_in) if not is_last else nn.Identity() 224 | ])) 225 | 226 | out_dim = default(out_dim, channels) 227 | self.final_conv = nn.Sequential( 228 | Block(dim, dim), 229 | nn.Conv2d(dim, out_dim, 1) 230 | ) 231 | 232 | def forward(self, x, time): 233 | length = x.shape[-1] 234 | x = F.pad(x, (0, math.ceil(length/4)*4-length), "constant", 0) 235 | t = self.time_mlp(time) if exists(self.time_mlp) else None 236 | 237 | h = [] 238 | 239 | for resnet, resnet2, attn, downsample in self.downs: 240 | x = resnet(x, t) 241 | x = resnet2(x, t) 242 | x = attn(x) 243 | h.append(x) 244 | x = downsample(x) 245 | 246 | x = self.mid_block1(x, t) 247 | x = self.mid_attn(x) 248 | x = self.mid_block2(x, t) 249 | 250 | for resnet, resnet2, attn, upsample in self.ups: 251 | 
x = torch.cat((x, h.pop()), dim=1) 252 | x = resnet(x, t) 253 | x = resnet2(x, t) 254 | x = attn(x) 255 | x = upsample(x) 256 | 257 | return torch.squeeze(self.final_conv(x),1)[:, :, :length] 258 | 259 | # gaussian diffusion trainer class 260 | 261 | def extract(a, t, x_shape): 262 | b, *_ = t.shape 263 | out = a.gather(-1, t) 264 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 265 | 266 | def noise_like(shape, device, repeat=False): 267 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 268 | noise = lambda: torch.randn(shape, device=device) 269 | return repeat_noise() if repeat else noise() 270 | 271 | def cosine_beta_schedule(timesteps, s = 0.008): 272 | """ 273 | cosine schedule 274 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 275 | """ 276 | steps = timesteps + 1 277 | x = np.linspace(0, steps, steps) 278 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 279 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 280 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 281 | return np.clip(betas, a_min = 0, a_max = 0.999) 282 | 283 | class GaussianDiffusion(nn.Module): 284 | def __init__( 285 | self, 286 | denoise_fn, 287 | *, 288 | image_size, 289 | channels = 3, 290 | timesteps = 1000, 291 | loss_type = 'l1', 292 | betas = None 293 | ): 294 | super().__init__() 295 | self.channels = channels 296 | self.image_size = image_size 297 | self.denoise_fn = denoise_fn 298 | 299 | if exists(betas): 300 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas 301 | else: 302 | betas = cosine_beta_schedule(timesteps) 303 | 304 | alphas = 1. - betas 305 | alphas_cumprod = np.cumprod(alphas, axis=0) 306 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 307 | 308 | timesteps, = betas.shape 309 | self.num_timesteps = int(timesteps) 310 | self.loss_type = loss_type 311 | 312 | to_torch = partial(torch.tensor, dtype=torch.float32) 313 | 314 | self.register_buffer('betas', to_torch(betas)) 315 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 316 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 317 | 318 | # calculations for diffusion q(x_t | x_{t-1}) and others 319 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 320 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 321 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 322 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 323 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 324 | 325 | # calculations for posterior q(x_{t-1} | x_t, x_0) 326 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 327 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 328 | self.register_buffer('posterior_variance', to_torch(posterior_variance)) 329 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 330 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 331 | self.register_buffer('posterior_mean_coef1', to_torch( 332 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 333 | self.register_buffer('posterior_mean_coef2', to_torch( 334 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) 335 | 336 | def q_mean_variance(self, x_start, t): 337 | mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 338 | variance = extract(1. - self.alphas_cumprod, t, x_start.shape) 339 | log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) 340 | return mean, variance, log_variance 341 | 342 | def predict_start_from_noise(self, x_t, t, noise): 343 | return ( 344 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 345 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 346 | ) 347 | 348 | def q_posterior(self, x_start, x_t, t): 349 | posterior_mean = ( 350 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 351 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 352 | ) 353 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 354 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 355 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 356 | 357 | def p_mean_variance(self, x, t, clip_denoised: bool): 358 | x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn(x, t)) 359 | 360 | if clip_denoised: 361 | x_recon.clamp_(-1., 1.) 362 | 363 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 364 | return model_mean, posterior_variance, posterior_log_variance 365 | 366 | @torch.no_grad() 367 | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): 368 | b, *_, device = *x.shape, x.device 369 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) 370 | noise = noise_like(x.shape, device, repeat_noise) 371 | # no noise when t == 0 372 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 373 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 374 | 375 | @torch.no_grad() 376 | def p_sample_loop(self, shape): 377 | device = self.betas.device 378 | 379 | b = shape[0] 380 | img = torch.randn(shape, device=device) 381 | 382 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 383 | img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) 384 | return img 385 | 386 | @torch.no_grad() 387 | def sample(self, batch_size = 16): 388 | image_size = self.image_size 389 | channels = self.channels 390 | return self.p_sample_loop((batch_size, channels, image_size, image_size)) 391 | 392 | @torch.no_grad() 393 | def interpolate(self, x1, x2, t = None, lam = 0.5): 394 | b, *_, device = *x1.shape, x1.device 395 | t = default(t, self.num_timesteps - 1) 396 | 397 | assert x1.shape == x2.shape 398 | 399 | t_batched = torch.stack([torch.tensor(t, device=device)] * b) 400 | xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) 401 | 402 | img = (1 - lam) * xt1 + lam * xt2 403 | for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t): 404 | img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) 405 | 406 | return img 407 | 408 | def q_sample(self, x_start, t, noise=None): 409 | noise = default(noise, lambda: torch.randn_like(x_start)) 410 | 411 | return ( 412 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 413 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 414 | ) 415 | 416 | def p_losses(self, x_start, t, noise = None): 417 | b, c, h, w = x_start.shape 418 | noise = default(noise, 
lambda: torch.randn_like(x_start)) 419 | 420 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 421 | x_recon = self.denoise_fn(x_noisy, t) 422 | 423 | if self.loss_type == 'l1': 424 | loss = (noise - x_recon).abs().mean() 425 | elif self.loss_type == 'l2': 426 | loss = F.mse_loss(noise, x_recon) 427 | else: 428 | raise NotImplementedError() 429 | 430 | return loss 431 | 432 | def forward(self, x, *args, **kwargs): 433 | b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size 434 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 435 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long() 436 | return self.p_losses(x, t, *args, **kwargs) 437 | 438 | 439 | 440 | if __name__ == '__main__': 441 | wg=Unet(64,dim_mults=(1, 2,4),out_dim=1) 442 | 443 | u=np.zeros([16,80,431],dtype=np.float32) 444 | x0=np.zeros([16,80,431],dtype=np.float32) 445 | 446 | u = torch.from_numpy(u) 447 | x0 = torch.from_numpy(x0) 448 | u = F.pad(u, (0,math.ceil(u.shape[2]/4)*4-u.shape[2]), "constant", 0) # effectively zero padding 449 | x0 = F.pad(x0,(0, math.ceil(x0.shape[2]/4)*4-x0.shape[2]), "constant", 0) # effectively zero padding 450 | print(u.shape) 451 | spectrogram = torch.stack((u,x0),dim=1) 452 | print(spectrogram.shape) 453 | N, T ,_,_ = spectrogram.shape 454 | S = 1000 455 | device = torch.device('cpu') 456 | s = torch.randint(1, S + 1, [N], device=device) 457 | 458 | x = wg(spectrogram,s) 459 | print(x.shape) 460 | 461 | 462 | --------------------------------------------------------------------------------
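Usage sketch (not a file in the repository): a minimal example of how the Unet and GaussianDiffusion classes defined in unet.py above could be wired together for one epsilon-prediction training step and for sampling. The batch size, the 2-channel 64x64 input shape, and the `from unet import ...` path are illustrative assumptions; GaussianDiffusion.forward asserts square inputs of side image_size, so square tensors are used here rather than the non-square mel-shaped tensors exercised in the __main__ block above.

import torch
from unet import Unet, GaussianDiffusion  # assumes unet.py is importable from the working directory

# Denoiser: 2 input channels; out_dim defaults to the channel count, so the
# network predicts per-channel noise with the same shape as its input.
denoiser = Unet(dim=64, dim_mults=(1, 2, 4), channels=2)

# Wrapper: cosine beta schedule, 1000 diffusion steps, L1 loss on the noise.
diffusion = GaussianDiffusion(denoiser, image_size=64, channels=2,
                              timesteps=1000, loss_type='l1')

x = torch.randn(4, 2, 64, 64)   # toy batch of square "images"
loss = diffusion(x)             # draws t, corrupts x via q_sample, predicts the added noise
loss.backward()

with torch.no_grad():
    samples = diffusion.sample(batch_size=4)   # full reverse process, shape (4, 2, 64, 64)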