├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── gd.png
│   └── lj-tensorboard.png
├── config
│   └── default.yaml
├── datasets
│   └── dataloader.py
├── hubconf.py
├── inference.py
├── model
│   ├── discriminator.py
│   ├── generator.py
│   ├── identity.py
│   ├── multiscale.py
│   └── res_stack.py
├── preprocess.py
├── requirements.txt
├── trainer.py
└── utils
    ├── audio_processing.py
    ├── hparams.py
    ├── plotting.py
    ├── stft.py
    ├── train.py
    ├── utils.py
    ├── validation.py
    └── writer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2019, Seungwon Park 박승원
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 |    list of conditions and the following disclaimer.
11 | 
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 |    this list of conditions and the following disclaimer in the documentation
14 |    and/or other materials provided with the distribution.
15 | 
16 | 3. Neither the name of the copyright holder nor the names of its
17 |    contributors may be used to endorse or promote products derived from
18 |    this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MelGAN
2 | Unofficial PyTorch implementation of [MelGAN vocoder](https://arxiv.org/abs/1910.06711)
3 | 
4 | ## Key Features
5 | 
6 | - MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](https://github.com/NVIDIA/waveglow).
7 | - This repository uses the same mel-spectrogram function as [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), so it can be used directly to convert the output of NVIDIA's Tacotron2 into raw audio.
8 | - Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub).
9 | 
10 | ![](./assets/gd.png)
11 | 
12 | ## Prerequisites
13 | 
14 | Tested on Python 3.6
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | Install the CPU-only build of PyTorch
20 | ```bash
21 | pip3 install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
22 | ```
23 | 
24 | ## Prepare Dataset
25 | 
26 | - Download a dataset for training. Any wav files with a 22050 Hz sample rate will do (e.g. LJSpeech, which was used in the paper).
27 | - Preprocess: `python preprocess.py -c config/default.yaml -d [data's root path]`
28 | - Edit the configuration `yaml` file
29 | 
30 | ## Train & Tensorboard
31 | 
32 | - `python trainer.py -c [config yaml file] -n [name of the run]`
33 |   - `cp config/default.yaml config/config.yaml` and then edit `config.yaml`
34 |   - Write the root paths of the train/validation data on the 2nd/3rd lines.
35 |   - Each path should contain `*.wav` files paired with corresponding (preprocessed) `*.mel` files.
36 |   - The data loader collects the list of files within each path recursively.
37 | - `tensorboard --logdir logs/`
38 | 
39 | ## Pretrained model
40 | 
41 | Try with Google Colab: TODO
42 | 
43 | ```python
44 | import torch
45 | vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
46 | vocoder.eval()
47 | mel = torch.randn(1, 80, 234)  # use your own mel-spectrogram here
48 | 
49 | if torch.cuda.is_available():
50 |     vocoder = vocoder.cuda()
51 |     mel = mel.cuda()
52 | 
53 | with torch.no_grad():
54 |     audio = vocoder.inference(mel)
55 | ```
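The returned `audio` is an int16 tensor sampled at 22050 Hz, so it can be written straight to a wav file. A minimal follow-up sketch (the output filename is arbitrary):

```python
# Save the generated int16 audio as a 22050 Hz wav file.
from scipy.io.wavfile import write
write('generated.wav', 22050, audio.cpu().numpy())
```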
56 | 
57 | ## Inference
58 | 
59 | - `python inference.py -p [checkpoint path] -i [input mel path]`
60 | 
61 | ## Results
62 | 
63 | See audio samples at: http://swpark.me/melgan/.
64 | 
65 | ![](./assets/lj-tensorboard.png)
66 | 
67 | 
68 | ## Implementation Authors
69 | 
70 | - [Seungwon Park](http://swpark.me) @ MINDsLab Inc. (yyyyy@snu.ac.kr, swpark@mindslab.ai)
71 | - Myunchul Joe @ MINDsLab Inc.
72 | - [Rishikesh](https://github.com/rishikksh20) @ DeepSync Technologies Pvt Ltd.
73 | 
74 | ## License
75 | 
76 | BSD 3-Clause License.
77 | 
78 | - [utils/stft.py](./utils/stft.py) by Prem Seetharaman (BSD 3-Clause License)
79 | - [datasets/mel2samp.py](./datasets/mel2samp.py) from https://github.com/NVIDIA/waveglow (BSD 3-Clause License)
80 | - [utils/hparams.py](./utils/hparams.py) from https://github.com/HarryVolek/PyTorch_Speaker_Verification (No License specified)
81 | 
82 | ## Useful resources
83 | 
84 | - [How to Train a GAN? Tips and tricks to make GANs work](https://github.com/soumith/ganhacks) by Soumith Chintala
85 | - [Official MelGAN implementation by original authors](https://github.com/descriptinc/melgan-neurips)
86 | 
--------------------------------------------------------------------------------
/assets/gd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rgzn-aiyun/melgan-cpu/668457534a4ee70ed38a16e8bf9a89d06c187aec/assets/gd.png
--------------------------------------------------------------------------------
/assets/lj-tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rgzn-aiyun/melgan-cpu/668457534a4ee70ed38a16e8bf9a89d06c187aec/assets/lj-tensorboard.png
--------------------------------------------------------------------------------
/config/default.yaml:
--------------------------------------------------------------------------------
1 | data: # root path of train/validation data (either a relative or an absolute path is ok)
2 |   train: ''
3 |   validation: ''
4 | ---
5 | train:
6 |   rep_discriminator: 1
7 |   num_workers: 32
8 |   batch_size: 16
9 |   optimizer: 'adam'
10 |   adam:
11 |     lr: 0.0001
12 |     beta1: 0.5
13 |     beta2: 0.9
14 | ---
15 | audio:
16 |   n_mel_channels: 80
17 |   segment_length: 16000
18 |   pad_short: 2000
19 |   filter_length: 1024
20 |   hop_length: 256 # WARNING: this can't be changed; it must equal the generator's total upsampling factor.
21 |   win_length: 1024
22 |   sampling_rate: 22050
23 |   mel_fmin: 0.0
24 |   mel_fmax: 8000.0
25 | ---
26 | model:
27 |   feat_match: 10.0
28 | ---
29 | log:
30 |   summary_interval: 1
31 |   validation_interval: 5
32 |   save_interval: 25
33 |   chkpt_dir: 'chkpt'
34 |   log_dir: 'logs'
--------------------------------------------------------------------------------
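Note: this config is a multi-document YAML file (hence the `---` separators); `utils/hparams.py` below loads every document and merges them into one dot-accessible object. A minimal sketch of how the values above are read:

```python
# Sketch: reading the multi-document YAML via utils/hparams.py.
from utils.hparams import HParam

hp = HParam('config/default.yaml')
print(hp.audio.hop_length)  # 256
print(hp.train.batch_size)  # 16
print(hp.log.chkpt_dir)     # 'chkpt'
```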
/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import random
5 | import numpy as np
6 | from torch.utils.data import Dataset, DataLoader
7 | 
8 | from utils.utils import read_wav_np
9 | 
10 | 
11 | def create_dataloader(hp, args, train):
12 |     dataset = MelFromDisk(hp, args, train)
13 | 
14 |     if train:
15 |         return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=True,
16 |             num_workers=hp.train.num_workers, pin_memory=True, drop_last=True)
17 |     else:
18 |         return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
19 |             num_workers=hp.train.num_workers, pin_memory=True, drop_last=False)
20 | 
21 | 
22 | class MelFromDisk(Dataset):
23 |     def __init__(self, hp, args, train):
24 |         self.hp = hp
25 |         self.args = args
26 |         self.train = train
27 |         self.path = hp.data.train if train else hp.data.validation
28 |         self.wav_list = glob.glob(os.path.join(self.path, '**', '*.wav'), recursive=True)
29 |         self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length + 2  # e.g. 16000 // 256 + 2 = 64 frames: the audio segment plus a 2-frame margin
30 |         self.mapping = [i for i in range(len(self.wav_list))]
31 | 
32 |     def __len__(self):
33 |         return len(self.wav_list)
34 | 
35 |     def __getitem__(self, idx):
36 |         if self.train:
37 |             idx1 = idx
38 |             idx2 = self.mapping[idx1]  # paired random index: one sample feeds the generator step, the other the discriminator step (see utils/train.py)
39 |             return self.my_getitem(idx1), self.my_getitem(idx2)
40 |         else:
41 |             return self.my_getitem(idx)
42 | 
43 |     def shuffle_mapping(self):
44 |         random.shuffle(self.mapping)
45 | 
46 |     def my_getitem(self, idx):
47 |         wavpath = self.wav_list[idx]
48 |         melpath = wavpath.replace('.wav', '.mel')
49 |         sr, audio = read_wav_np(wavpath)
50 |         if len(audio) < self.hp.audio.segment_length + self.hp.audio.pad_short:
51 |             audio = np.pad(audio, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(audio)), \
52 |                 mode='constant', constant_values=0.0)
53 | 
54 |         audio = torch.from_numpy(audio).unsqueeze(0)
55 |         mel = torch.load(melpath).squeeze(0)
56 | 
57 |         if self.train:
58 |             max_mel_start = mel.size(1) - self.mel_segment_length
59 |             mel_start = random.randint(0, max_mel_start)
60 |             mel_end = mel_start + self.mel_segment_length
61 |             mel = mel[:, mel_start:mel_end]
62 | 
63 |             audio_start = mel_start * self.hp.audio.hop_length
64 |             audio = audio[:, audio_start:audio_start+self.hp.audio.segment_length]
65 | 
66 |         audio = audio + (1/32768) * torch.randn_like(audio)  # dither with roughly 1 LSB of int16 noise
67 |         return mel, audio
--------------------------------------------------------------------------------
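In training mode each item yields two independent (mel, audio) pairs: one consumed by the generator update and one by the discriminator update in `utils/train.py`. Under the default config the collated batch shapes look like this (a sketch; it assumes `trainloader` as returned by `create_dataloader`):

```python
# Sketch: one training batch under config/default.yaml
# (batch_size=16, segment_length=16000, hop_length=256).
# assumes: trainloader = create_dataloader(hp, args, train=True)
(melG, audioG), (melD, audioD) = next(iter(trainloader))
print(melG.shape)    # torch.Size([16, 80, 64])   -> 16000 // 256 + 2 mel frames
print(audioG.shape)  # torch.Size([16, 1, 16000])
```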
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ['torch']
2 | import torch
3 | from model.generator import Generator
4 | 
5 | model_params = {
6 |     'nvidia_tacotron2_LJ11_epoch3200': {
7 |         'mel_channel': 80,
8 |         'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.2-alpha/nvidia_tacotron2_LJ11_epoch3200_v02.pt',
9 |     },
10 | }
11 | 
12 | 
13 | def melgan(model_name='nvidia_tacotron2_LJ11_epoch3200', pretrained=True, progress=True):
14 |     params = model_params[model_name]
15 |     model = Generator(params['mel_channel'])
16 | 
17 |     if pretrained:
18 |         state_dict = torch.hub.load_state_dict_from_url(params['model_url'],
19 |                                                         progress=progress)
20 |         model.load_state_dict(state_dict['model_g'])
21 | 
22 |     model.eval(inference=True)
23 | 
24 |     return model
25 | 
26 | 
27 | if __name__ == '__main__':
28 |     vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
29 |     mel = torch.randn(1, 80, 234)  # use your own mel-spectrogram here
30 | 
31 |     print('Input mel-spectrogram shape: {}'.format(mel.shape))
32 | 
33 |     if torch.cuda.is_available():
34 |         print('Moving data & model to GPU')
35 |         vocoder = vocoder.cuda()
36 |         mel = mel.cuda()
37 | 
38 |     with torch.no_grad():
39 |         audio = vocoder.inference(mel)
40 | 
41 |     print('Output audio shape: {}'.format(audio.shape))
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import tqdm
4 | import torch
5 | import argparse
6 | from scipy.io.wavfile import write
7 | 
8 | from model.generator import Generator
9 | from utils.hparams import HParam, load_hparam_str
10 | 
11 | MAX_WAV_VALUE = 32768.0
12 | 
13 | def main(args):
14 |     checkpoint = torch.load(args.checkpoint_path, map_location=torch.device('cpu'))
15 |     if args.config is not None:
16 |         hp = HParam(args.config)
17 |         print(hp)
18 |     else:
19 |         hp = load_hparam_str(checkpoint['hp_str'])
20 |         print(hp)
21 | 
22 |     model = Generator(hp.audio.n_mel_channels)
23 |     model.load_state_dict(checkpoint['model_g'])
24 |     model.eval(inference=False)
25 | 
26 |     with torch.no_grad():
27 |         for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
28 |             mel = torch.load(melpath)
29 |             if len(mel.shape) == 2:
30 |                 mel = mel.unsqueeze(0)
31 | 
32 |             audio = model.inference(mel)
33 |             audio = audio.cpu().detach().numpy()
34 | 
35 |             out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
36 |             write(out_path, hp.audio.sampling_rate, audio)
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     parser = argparse.ArgumentParser()
41 |     parser.add_argument('-c', '--config', type=str, default=None,
42 |                         help="yaml file for config. will use hp_str from checkpoint if not given.")
43 |     parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
44 |                         help="path of checkpoint pt file for evaluation")
45 |     parser.add_argument('-i', '--input_folder', type=str, required=True,
46 |                         help="directory of mel-spectrograms to invert into raw audio.")
") 47 | args = parser.parse_args() 48 | 49 | main(args) 50 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self): 8 | super(Discriminator, self).__init__() 9 | 10 | self.discriminator = nn.ModuleList([ 11 | nn.Sequential( 12 | nn.ReflectionPad1d(7), 13 | nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1)), 14 | nn.LeakyReLU(0.2, inplace=True), 15 | ), 16 | nn.Sequential( 17 | nn.utils.weight_norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, padding=20, groups=4)), 18 | nn.LeakyReLU(0.2, inplace=True), 19 | ), 20 | nn.Sequential( 21 | nn.utils.weight_norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, padding=20, groups=16)), 22 | nn.LeakyReLU(0.2, inplace=True), 23 | ), 24 | nn.Sequential( 25 | nn.utils.weight_norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, padding=20, groups=64)), 26 | nn.LeakyReLU(0.2, inplace=True), 27 | ), 28 | nn.Sequential( 29 | nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256)), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | ), 32 | nn.Sequential( 33 | nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=2)), 34 | nn.LeakyReLU(0.2, inplace=True), 35 | ), 36 | nn.utils.weight_norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1)), 37 | ]) 38 | 39 | def forward(self, x): 40 | ''' 41 | returns: (list of 6 features, discriminator score) 42 | we directly predict score without last sigmoid function 43 | since we're using Least Squares GAN (https://arxiv.org/abs/1611.04076) 44 | ''' 45 | features = list() 46 | for module in self.discriminator: 47 | x = module(x) 48 | features.append(x) 49 | return features[:-1], features[-1] 50 | 51 | 52 | if __name__ == '__main__': 53 | model = Discriminator() 54 | 55 | x = torch.randn(3, 1, 22050) 56 | print(x.shape) 57 | 58 | features, score = model(x) 59 | for feat in features: 60 | print(feat.shape) 61 | print(score.shape) 62 | 63 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 64 | print(pytorch_total_params) -------------------------------------------------------------------------------- /model/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .res_stack import ResStack 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, mel_channel): 12 | super(Generator, self).__init__() 13 | self.mel_channel = mel_channel 14 | 15 | self.generator = nn.Sequential( 16 | nn.ReflectionPad1d(3), 17 | nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)), 18 | 19 | nn.LeakyReLU(0.2), 20 | nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)), 21 | 22 | ResStack(256), 23 | 24 | nn.LeakyReLU(0.2), 25 | nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)), 26 | 27 | ResStack(128), 28 | 29 | nn.LeakyReLU(0.2), 30 | nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)), 31 | 32 | ResStack(64), 33 | 34 | nn.LeakyReLU(0.2), 35 | nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)), 36 | 37 | ResStack(32), 38 | 39 | nn.LeakyReLU(0.2), 
/model/identity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class Identity(nn.Module):
7 |     def __init__(self):
8 |         super(Identity, self).__init__()
9 | 
10 |     def forward(self, x):
11 |         return x
12 | 
--------------------------------------------------------------------------------
/model/multiscale.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .discriminator import Discriminator
6 | from .identity import Identity
7 | 
8 | 
9 | class MultiScaleDiscriminator(nn.Module):
10 |     def __init__(self):
11 |         super(MultiScaleDiscriminator, self).__init__()
12 | 
13 |         self.discriminators = nn.ModuleList(
14 |             [Discriminator() for _ in range(3)]
15 |         )
16 | 
17 |         self.pooling = nn.ModuleList(
18 |             [Identity()] +
19 |             [nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) for _ in range(1, 3)]
20 |         )
21 | 
22 |     def forward(self, x):
23 |         ret = list()
24 | 
25 |         for pool, disc in zip(self.pooling, self.discriminators):
26 |             x = pool(x)
27 |             ret.append(disc(x))
28 | 
29 |         return ret  # [(feat, score), (feat, score), (feat, score)]
--------------------------------------------------------------------------------
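`MultiScaleDiscriminator` applies the same `Discriminator` at three time scales: the raw audio, and the audio average-pooled once and twice (each pooling halves the length). A shape sketch for a 1-second clip:

```python
# Sketch: the three discriminator scales see 22050, 11025, and 5512 samples.
import torch
from model.multiscale import MultiScaleDiscriminator

disc = MultiScaleDiscriminator()
x = torch.randn(3, 1, 22050)
for feats, score in disc(x):
    # one (features, score) pair per scale; each scale's input is half the previous
    print(score.shape)
```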
/model/res_stack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | 
7 | class ResStack(nn.Module):
8 |     def __init__(self, channel):
9 |         super(ResStack, self).__init__()
10 | 
11 |         self.blocks = nn.ModuleList([
12 |             nn.Sequential(
13 |                 nn.LeakyReLU(0.2),
14 |                 nn.ReflectionPad1d(3**i),
15 |                 nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i)),  # dilations 1, 3, 9; the matching pad keeps the length unchanged
16 |                 nn.LeakyReLU(0.2),
17 |                 nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
18 |             )
19 |             for i in range(3)
20 |         ])
21 | 
22 |         self.shortcuts = nn.ModuleList([
23 |             nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
24 |             for i in range(3)
25 |         ])
26 | 
27 |     def forward(self, x):
28 |         for block, shortcut in zip(self.blocks, self.shortcuts):
29 |             x = shortcut(x) + block(x)
30 |         return x
31 | 
32 |     def remove_weight_norm(self):
33 |         for block, shortcut in zip(self.blocks, self.shortcuts):
34 |             nn.utils.remove_weight_norm(block[2])
35 |             nn.utils.remove_weight_norm(block[4])
36 |             nn.utils.remove_weight_norm(shortcut)
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import tqdm
4 | import torch
5 | import argparse
6 | import numpy as np
7 | 
8 | from utils.stft import TacotronSTFT
9 | from utils.hparams import HParam
10 | from utils.utils import read_wav_np
11 | 
12 | 
13 | def main(hp, args):
14 |     stft = TacotronSTFT(filter_length=hp.audio.filter_length,
15 |                         hop_length=hp.audio.hop_length,
16 |                         win_length=hp.audio.win_length,
17 |                         n_mel_channels=hp.audio.n_mel_channels,
18 |                         sampling_rate=hp.audio.sampling_rate,
19 |                         mel_fmin=hp.audio.mel_fmin,
20 |                         mel_fmax=hp.audio.mel_fmax)
21 | 
22 |     wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
23 | 
24 |     for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
25 |         sr, wav = read_wav_np(wavpath)
26 |         assert sr == hp.audio.sampling_rate, \
27 |             "sample rate mismatch. expected %d, got %d at %s" % \
28 |             (hp.audio.sampling_rate, sr, wavpath)
29 | 
30 |         if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
31 |             wav = np.pad(wav, (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)), \
32 |                 mode='constant', constant_values=0.0)
33 | 
34 |         wav = torch.from_numpy(wav).unsqueeze(0)
35 |         mel = stft.mel_spectrogram(wav)
36 | 
37 |         melpath = wavpath.replace('.wav', '.mel')
38 |         torch.save(mel, melpath)
39 | 
40 | 
41 | if __name__ == '__main__':
42 |     parser = argparse.ArgumentParser()
43 |     parser.add_argument('-c', '--config', type=str, required=True,
44 |                         help="yaml file for config.")
45 |     parser.add_argument('-d', '--data_path', type=str, required=True,
46 |                         help="root directory of wav files")
47 |     args = parser.parse_args()
48 |     hp = HParam(args.config)
49 | 
50 |     main(hp, args)
--------------------------------------------------------------------------------
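Each `.mel` file written here is a serialized FloatTensor of shape (1, n_mel_channels, T); with the STFT's centered padding, T = len(wav) // hop_length + 1. A quick check (the path is a placeholder):

```python
# Inspect a preprocessed mel file (1 x 80 x T under the default config).
import torch

mel = torch.load('some_clip.mel')
print(mel.shape)  # e.g. torch.Size([1, 80, 71]) for an 18000-sample wav
```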
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa
2 | matplotlib
3 | numpy
4 | scipy
5 | tensorboardX
6 | tqdm
7 | pillow
8 | pyyaml
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import argparse
5 | 
6 | from utils.train import train
7 | from utils.hparams import HParam
8 | from utils.writer import MyWriter
9 | from datasets.dataloader import create_dataloader
10 | 
11 | 
12 | if __name__ == '__main__':
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument('-c', '--config', type=str, required=True,
15 |                         help="yaml file for configuration")
16 |     parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
17 |                         help="path of checkpoint pt file to resume training")
18 |     parser.add_argument('-n', '--name', type=str, required=True,
19 |                         help="name of the model for logging, saving checkpoint")
20 |     args = parser.parse_args()
21 | 
22 |     hp = HParam(args.config)
23 |     with open(args.config, 'r') as f:
24 |         hp_str = ''.join(f.readlines())
25 | 
26 |     pt_dir = os.path.join(hp.log.chkpt_dir, args.name)
27 |     log_dir = os.path.join(hp.log.log_dir, args.name)
28 |     os.makedirs(pt_dir, exist_ok=True)
29 |     os.makedirs(log_dir, exist_ok=True)
30 | 
31 |     logging.basicConfig(
32 |         level=logging.INFO,
33 |         format='%(asctime)s - %(levelname)s - %(message)s',
34 |         handlers=[
35 |             logging.FileHandler(os.path.join(log_dir,
36 |                 '%s-%d.log' % (args.name, time.time()))),
37 |             logging.StreamHandler()
38 |         ]
39 |     )
40 |     logger = logging.getLogger()
41 | 
42 |     writer = MyWriter(hp, log_dir)
43 | 
44 |     assert hp.audio.hop_length == 256, \
45 |         'hp.audio.hop_length must be equal to 256, got %d' % hp.audio.hop_length
46 |     assert hp.data.train != '' and hp.data.validation != '', \
47 |         'hp.data.train and hp.data.validation can\'t be empty: please fix %s' % args.config
48 | 
49 |     trainloader = create_dataloader(hp, args, True)
50 |     valloader = create_dataloader(hp, args, False)
51 | 
52 |     train(args, pt_dir, args.checkpoint_path, trainloader, valloader, writer, logger, hp, hp_str)
--------------------------------------------------------------------------------
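To resume a run, pass a checkpoint saved by `utils/train.py` via `-p`; the checkpoint carries its own config string (`hp_str`), step, epoch, and git hash. For example (the checkpoint filename here is hypothetical; real names follow the `[name]_[githash]_[epoch].pt` pattern):

```bash
python trainer.py -c config/config.yaml -n myrun -p chkpt/myrun/myrun_0a1b2c3_0025.pt
```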
/utils/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 | 
6 | 
7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
8 |                      n_fft=800, dtype=np.float32, norm=None):
9 |     """
10 |     # from librosa 0.6
11 |     Compute the sum-square envelope of a window function at a given hop length.
12 | 
13 |     This is used to estimate modulation effects induced by windowing
14 |     observations in short-time fourier transforms.
15 | 
16 |     Parameters
17 |     ----------
18 |     window : string, tuple, number, callable, or list-like
19 |         Window specification, as in `get_window`
20 | 
21 |     n_frames : int > 0
22 |         The number of analysis frames
23 | 
24 |     hop_length : int > 0
25 |         The number of samples to advance between frames
26 | 
27 |     win_length : [optional]
28 |         The length of the window function. By default, this matches `n_fft`.
29 | 
30 |     n_fft : int > 0
31 |         The length of each analysis frame.
32 | 
33 |     dtype : np.dtype
34 |         The data type of the output
35 | 
36 |     Returns
37 |     -------
38 |     wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
39 |         The sum-squared envelope of the window function
40 |     """
41 |     if win_length is None:
42 |         win_length = n_fft
43 | 
44 |     n = n_fft + hop_length * (n_frames - 1)
45 |     x = np.zeros(n, dtype=dtype)
46 | 
47 |     # Compute the squared window at the desired length
48 |     win_sq = get_window(window, win_length, fftbins=True)
49 |     win_sq = librosa_util.normalize(win_sq, norm=norm)**2
50 |     win_sq = librosa_util.pad_center(win_sq, n_fft)
51 | 
52 |     # Fill the envelope
53 |     for i in range(n_frames):
54 |         sample = i * hop_length
55 |         x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
56 |     return x
57 | 
58 | 
59 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
60 |     """
61 |     PARAMS
62 |     ------
63 |     magnitudes: spectrogram magnitudes
64 |     stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
65 |     """
66 | 
67 |     angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
68 |     angles = angles.astype(np.float32)
69 |     angles = torch.autograd.Variable(torch.from_numpy(angles))
70 |     signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
71 | 
72 |     for i in range(n_iters):
73 |         _, angles = stft_fn.transform(signal)
74 |         signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
75 |     return signal
76 | 
77 | 
78 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
79 |     """
80 |     PARAMS
81 |     ------
82 |     C: compression factor
83 |     """
84 |     return torch.log(torch.clamp(x, min=clip_val) * C)
85 | 
86 | 
87 | def dynamic_range_decompression(x, C=1):
88 |     """
89 |     PARAMS
90 |     ------
91 |     C: compression factor used to compress
92 |     """
93 |     return torch.exp(x) / C
--------------------------------------------------------------------------------
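`dynamic_range_compression` clamps magnitudes at 1e-5 before taking the log, so the smallest possible log-mel value is log(1e-5) ≈ -11.5129, exactly the constant `Generator.inference` pads with to append digital silence (see `model/generator.py`):

```python
# The generator's padding constant is the log of the clamp floor used above.
import math
print(math.log(1e-5))  # -11.512925464970229
```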
/utils/hparams.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/HarryVolek/PyTorch_Speaker_Verification
2 | 
3 | import os
4 | import yaml
5 | 
6 | 
7 | def load_hparam_str(hp_str):
8 |     path = 'temp-restore.yaml'
9 |     with open(path, 'w') as f:
10 |         f.write(hp_str)
11 |     ret = HParam(path)
12 |     os.remove(path)
13 |     return ret
14 | 
15 | 
16 | def load_hparam(filename):
17 |     stream = open(filename, 'r')
18 |     docs = yaml.load_all(stream, Loader=yaml.Loader)
19 |     hparam_dict = dict()
20 |     for doc in docs:
21 |         for k, v in doc.items():
22 |             hparam_dict[k] = v
23 |     return hparam_dict
24 | 
25 | 
26 | def merge_dict(user, default):
27 |     if isinstance(user, dict) and isinstance(default, dict):
28 |         for k, v in default.items():
29 |             if k not in user:
30 |                 user[k] = v
31 |             else:
32 |                 user[k] = merge_dict(user[k], v)
33 |     return user
34 | 
35 | 
36 | class Dotdict(dict):
37 |     """
38 |     a dictionary that supports dot notation
39 |     as well as dictionary access notation
40 |     usage: d = Dotdict() or d = Dotdict({'val1':'first'})
41 |     set attributes: d.val2 = 'second' or d['val2'] = 'second'
42 |     get attributes: d.val2 or d['val2']
43 |     """
44 |     __getattr__ = dict.__getitem__
45 |     __setattr__ = dict.__setitem__
46 |     __delattr__ = dict.__delitem__
47 | 
48 |     def __init__(self, dct=None):
49 |         dct = dict() if not dct else dct
50 |         for key, value in dct.items():
51 |             if hasattr(value, 'keys'):
52 |                 value = Dotdict(value)
53 |             self[key] = value
54 | 
55 | 
56 | class HParam(Dotdict):
57 | 
58 |     def __init__(self, file):
59 |         super(Dotdict, self).__init__()
60 |         hp_dict = load_hparam(file)
61 |         hp_dotdict = Dotdict(hp_dict)
62 |         for k, v in hp_dotdict.items():
63 |             setattr(self, k, v)
64 | 
65 |     __getattr__ = Dotdict.__getitem__
66 |     __setattr__ = Dotdict.__setitem__
67 |     __delattr__ = Dotdict.__delitem__
--------------------------------------------------------------------------------
/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use("Agg")
3 | import matplotlib.pylab as plt
4 | import numpy as np
5 | 
6 | 
7 | def save_figure_to_numpy(fig):
8 |     # save it to a numpy array (np.fromstring is deprecated; frombuffer reads the RGB bytes directly)
9 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
10 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
11 |     data = np.transpose(data, (2, 0, 1))
12 |     return data
13 | 
14 | 
15 | def plot_waveform_to_numpy(waveform):
16 |     fig, ax = plt.subplots(figsize=(12, 3))
17 |     ax.plot(range(len(waveform)), waveform,
18 |             linewidth=0.1, alpha=0.7, color='blue')
19 | 
20 |     plt.xlabel("Samples")
21 |     plt.ylabel("Amplitude")
22 |     plt.ylim(-1, 1)
23 |     plt.tight_layout()
24 | 
25 |     fig.canvas.draw()
26 |     data = save_figure_to_numpy(fig)
27 |     plt.close()
28 |     return data
--------------------------------------------------------------------------------
/utils/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 | 
4 | Copyright (c) 2017, Prem Seetharaman
5 | All rights reserved.
6 | 
7 | * Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 | 
10 | * Redistributions of source code must retain the above copyright notice,
11 | this list of conditions and the following disclaimer.
12 | 
13 | * Redistributions in binary form must reproduce the above copyright notice, this
14 | list of conditions and the following disclaimer in the
15 | documentation and/or other materials provided with the distribution.
16 | 
17 | * Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from this
19 | software without specific prior written permission.
20 | 
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | """
32 | 
33 | import torch
34 | import numpy as np
35 | import torch.nn.functional as F
36 | from torch.autograd import Variable
37 | from scipy.signal import get_window
38 | from librosa.util import pad_center, tiny
39 | from .audio_processing import window_sumsquare, dynamic_range_compression, dynamic_range_decompression
40 | from librosa.filters import mel as librosa_mel_fn
41 | 
42 | 
43 | class STFT(torch.nn.Module):
44 |     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
45 |     def __init__(self, filter_length=800, hop_length=200, win_length=800,
46 |                  window='hann'):
47 |         super(STFT, self).__init__()
48 |         self.filter_length = filter_length
49 |         self.hop_length = hop_length
50 |         self.win_length = win_length
51 |         self.window = window
52 |         self.forward_transform = None
53 |         scale = self.filter_length / self.hop_length
54 |         fourier_basis = np.fft.fft(np.eye(self.filter_length))
55 | 
56 |         cutoff = int((self.filter_length / 2 + 1))
57 |         fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
58 |                                    np.imag(fourier_basis[:cutoff, :])])
59 | 
60 |         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
61 |         inverse_basis = torch.FloatTensor(
62 |             np.linalg.pinv(scale * fourier_basis).T[:, None, :])
63 | 
64 |         if window is not None:
65 |             assert(filter_length >= win_length)
66 |             # get window and zero center pad it to filter_length
67 |             fft_window = get_window(window, win_length, fftbins=True)
68 |             fft_window = pad_center(fft_window, filter_length)
69 |             fft_window = torch.from_numpy(fft_window).float()
70 | 
71 |             # window the bases
72 |             forward_basis *= fft_window
73 |             inverse_basis *= fft_window
74 | 
75 |         self.register_buffer('forward_basis', forward_basis.float())
76 |         self.register_buffer('inverse_basis', inverse_basis.float())
77 | 
78 |     def transform(self, input_data):
79 |         num_batches = input_data.size(0)
80 |         num_samples = input_data.size(1)
81 | 
82 |         self.num_samples = num_samples
83 | 
84 |         # similar to librosa, reflect-pad the input
85 |         input_data = input_data.view(num_batches, 1, num_samples)
86 |         input_data = F.pad(
87 |             input_data.unsqueeze(1),
88 |             (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
89 |             mode='reflect')
90 |         input_data = input_data.squeeze(1)
91 | 
92 |         # https://github.com/NVIDIA/tacotron2/issues/125 (upstream ran this conv via .cuda(); kept device-agnostic here so the CPU-only install from the README works)
93 |         forward_transform = F.conv1d(
94 |             input_data,
95 |             Variable(self.forward_basis, requires_grad=False),
96 |             stride=self.hop_length,
97 |             padding=0)
98 | 
99 |         cutoff = int((self.filter_length / 2) + 1)
100 |         real_part = forward_transform[:, :cutoff, :]
101 |         imag_part = forward_transform[:, cutoff:, :]
102 | 
103 |         magnitude = torch.sqrt(real_part**2 + imag_part**2)
104 |         phase = torch.autograd.Variable(
105 |             torch.atan2(imag_part.data, real_part.data))
106 | 
107 |         return magnitude, phase
108 | 
109 |     def inverse(self, magnitude, phase):
110 |         recombine_magnitude_phase = torch.cat(
111 |             [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
112 | 
113 |         inverse_transform = F.conv_transpose1d(
114 |             recombine_magnitude_phase,
115 |             Variable(self.inverse_basis, requires_grad=False),
116 |             stride=self.hop_length,
117 |             padding=0)
118 | 
119 |         if self.window is not None:
120 |             window_sum = window_sumsquare(
121 |                 self.window, magnitude.size(-1), hop_length=self.hop_length,
122 |                 win_length=self.win_length, n_fft=self.filter_length,
123 |                 dtype=np.float32)
124 |             # remove modulation effects
125 |             approx_nonzero_indices = torch.from_numpy(
126 |                 np.where(window_sum > tiny(window_sum))[0])
127 |             window_sum = torch.autograd.Variable(
128 |                 torch.from_numpy(window_sum), requires_grad=False)
129 |             window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
130 |             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
131 | 
132 |             # scale by hop ratio
133 |             inverse_transform *= float(self.filter_length) / self.hop_length
134 | 
135 |         inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
136 |         inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
137 | 
138 |         return inverse_transform
139 | 
140 |     def forward(self, input_data):
141 |         self.magnitude, self.phase = self.transform(input_data)
142 |         reconstruction = self.inverse(self.magnitude, self.phase)
143 |         return reconstruction
144 | 
145 | 
146 | class TacotronSTFT(torch.nn.Module):
147 |     def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
148 |                  n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
149 |                  mel_fmax=None):
150 |         super(TacotronSTFT, self).__init__()
151 |         self.n_mel_channels = n_mel_channels
152 |         self.sampling_rate = sampling_rate
153 |         self.stft_fn = STFT(filter_length, hop_length, win_length)
154 |         mel_basis = librosa_mel_fn(
155 |             sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
156 |         mel_basis = torch.from_numpy(mel_basis).float()
157 |         self.register_buffer('mel_basis', mel_basis)
158 | 
159 |     def spectral_normalize(self, magnitudes):
160 |         output = dynamic_range_compression(magnitudes)
161 |         return output
162 | 
163 |     def spectral_de_normalize(self, magnitudes):
164 |         output = dynamic_range_decompression(magnitudes)
165 |         return output
166 | 
167 |     def mel_spectrogram(self, y):
168 |         """Computes mel-spectrograms from a batch of waves
169 |         PARAMS
170 |         ------
171 |         y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
172 | 
173 |         RETURNS
174 |         -------
175 |         mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
176 |         """
177 |         assert(torch.min(y.data) >= -1)
178 |         assert(torch.max(y.data) <= 1)
179 | 
180 |         magnitudes, phases = self.stft_fn.transform(y)
181 |         magnitudes = magnitudes.data
182 |         mel_output = torch.matmul(self.mel_basis, magnitudes)
183 |         mel_output = self.spectral_normalize(mel_output)
184 |         return mel_output
185 | 
--------------------------------------------------------------------------------
/utils/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import tqdm
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import itertools
8 | import traceback
9 | 
10 | from model.generator import Generator
11 | from model.multiscale import MultiScaleDiscriminator
12 | from .utils import get_commit_hash
13 | from .validation import validate
14 | 
15 | 
16 | def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
17 |     model_g = Generator(hp.audio.n_mel_channels).cuda()
18 |     model_d = MultiScaleDiscriminator().cuda()
19 | 
20 |     optim_g = torch.optim.Adam(model_g.parameters(),
21 |         lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
22 |     optim_d = torch.optim.Adam(model_d.parameters(),
23 |         lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
24 | 
25 |     githash = get_commit_hash()
26 | 
27 |     init_epoch = -1
28 |     step = 0
29 | 
30 |     if chkpt_path is not None:
31 |         logger.info("Resuming from checkpoint: %s" % chkpt_path)
32 |         checkpoint = torch.load(chkpt_path)
33 |         model_g.load_state_dict(checkpoint['model_g'])
34 |         model_d.load_state_dict(checkpoint['model_d'])
35 |         optim_g.load_state_dict(checkpoint['optim_g'])
36 |         optim_d.load_state_dict(checkpoint['optim_d'])
37 |         step = checkpoint['step']
38 |         init_epoch = checkpoint['epoch']
39 | 
40 |         if hp_str != checkpoint['hp_str']:
41 |             logger.warning("New hparams is different from checkpoint. Will use new.")
42 | 
43 |         if githash != checkpoint['githash']:
44 |             logger.warning("Code might be different: git hash is different.")
45 |             logger.warning("%s -> %s" % (checkpoint['githash'], githash))
46 | 
47 |     else:
48 |         logger.info("Starting new training run.")
49 | 
50 |     # this accelerates training when the size of minibatch is always consistent.
51 |     # if not consistent, it'll horribly slow down.
52 |     torch.backends.cudnn.benchmark = True
53 | 
54 |     try:
55 |         model_g.train()
56 |         model_d.train()
57 |         for epoch in itertools.count(init_epoch+1):
58 |             if epoch % hp.log.validation_interval == 0:
59 |                 with torch.no_grad():
60 |                     validate(hp, args, model_g, model_d, valloader, writer, step)
61 | 
62 |             trainloader.dataset.shuffle_mapping()
63 |             loader = tqdm.tqdm(trainloader, desc='Loading train data')
64 |             for (melG, audioG), (melD, audioD) in loader:
65 |                 melG = melG.cuda()
66 |                 audioG = audioG.cuda()
67 |                 melD = melD.cuda()
68 |                 audioD = audioD.cuda()
69 | 
70 |                 # generator
71 |                 optim_g.zero_grad()
72 |                 fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
73 |                 disc_fake = model_d(fake_audio)
74 |                 disc_real = model_d(audioG)
75 |                 loss_g = 0.0
76 |                 for (feats_fake, score_fake), (feats_real, _) in zip(disc_fake, disc_real):
77 |                     loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
78 |                     for feat_f, feat_r in zip(feats_fake, feats_real):
79 |                         loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
80 | 
81 |                 loss_g.backward()
82 |                 optim_g.step()
83 | 
84 |                 # discriminator
85 |                 fake_audio = model_g(melD)[:, :, :hp.audio.segment_length]
86 |                 fake_audio = fake_audio.detach()
87 |                 loss_d_sum = 0.0
88 |                 for _ in range(hp.train.rep_discriminator):
89 |                     optim_d.zero_grad()
90 |                     disc_fake = model_d(fake_audio)
91 |                     disc_real = model_d(audioD)
92 |                     loss_d = 0.0
93 |                     for (_, score_fake), (_, score_real) in zip(disc_fake, disc_real):
94 |                         loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
95 |                         loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
96 | 
97 |                     loss_d.backward()
98 |                     optim_d.step()
99 |                     loss_d_sum += loss_d
100 | 
101 |                 step += 1
102 |                 # logging
103 |                 loss_g = loss_g.item()
104 |                 loss_d_avg = loss_d_sum / hp.train.rep_discriminator
105 |                 loss_d_avg = loss_d_avg.item()
106 |                 if any([loss_g > 1e8, math.isnan(loss_g), loss_d_avg > 1e8, math.isnan(loss_d_avg)]):
107 |                     logger.error("loss_g %.01f loss_d_avg %.01f at step %d!" % (loss_g, loss_d_avg, step))
108 |                     raise Exception("Loss exploded")
109 | 
110 |                 if step % hp.log.summary_interval == 0:
111 |                     writer.log_training(loss_g, loss_d_avg, step)
112 |                     loader.set_description("g %.04f d %.04f | step %d" % (loss_g, loss_d_avg, step))
113 | 
114 |             if epoch % hp.log.save_interval == 0:
115 |                 save_path = os.path.join(pt_dir, '%s_%s_%04d.pt'
116 |                     % (args.name, githash, epoch))
117 |                 torch.save({
118 |                     'model_g': model_g.state_dict(),
119 |                     'model_d': model_d.state_dict(),
120 |                     'optim_g': optim_g.state_dict(),
121 |                     'optim_d': optim_d.state_dict(),
122 |                     'step': step,
123 |                     'epoch': epoch,
124 |                     'hp_str': hp_str,
125 |                     'githash': githash,
126 |                 }, save_path)
127 |                 logger.info("Saved checkpoint to: %s" % save_path)
128 | 
129 |     except Exception as e:
130 |         logger.info("Exiting due to exception: %s" % e)
131 |         traceback.print_exc()
--------------------------------------------------------------------------------
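For reference, the loop above implements the least-squares GAN objective with feature matching (weight `hp.model.feat_match` = 10), summed over the three discriminator scales k:

```
loss_D = Σ_k  E[(D_k(x) - 1)²] + E[D_k(G(s))²]
loss_G = Σ_k  E[(D_k(G(s)) - 1)²] + 10 · Σ_i E[|D_k_i(G(s)) - D_k_i(x)|]
```

where s is the input mel, x the real audio, and D_k_i the i-th intermediate feature map of the k-th discriminator.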
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import subprocess
3 | import numpy as np
4 | from scipy.io.wavfile import read
5 | 
6 | 
7 | def get_commit_hash():
8 |     message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
9 |     return message.strip().decode('utf-8')
10 | 
11 | def read_wav_np(path):
12 |     sr, wav = read(path)
13 | 
14 |     if len(wav.shape) == 2:
15 |         wav = wav[:, 0]
16 | 
17 |     if wav.dtype == np.int16:
18 |         wav = wav / 32768.0
19 |     elif wav.dtype == np.int32:
20 |         wav = wav / 2147483648.0
21 |     elif wav.dtype == np.uint8:
22 |         wav = (wav - 128) / 128.0
23 | 
24 |     wav = wav.astype(np.float32)
25 | 
26 |     return sr, wav
--------------------------------------------------------------------------------
/utils/validation.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import torch
3 | 
4 | 
5 | def validate(hp, args, generator, discriminator, valloader, writer, step):
6 |     generator.eval()
7 |     discriminator.eval()
8 |     torch.backends.cudnn.benchmark = False
9 | 
10 |     loader = tqdm.tqdm(valloader, desc='Validation loop')
11 |     loss_g_sum = 0.0
12 |     loss_d_sum = 0.0
13 |     for mel, audio in loader:
14 |         mel = mel.cuda()
15 |         audio = audio.cuda()
16 | 
17 |         # generator
18 |         fake_audio = generator(mel)
19 |         disc_fake = discriminator(fake_audio[:, :, :audio.size(2)])
20 |         disc_real = discriminator(audio)
21 |         loss_g = 0.0
22 |         loss_d = 0.0
23 |         for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
24 |             loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
25 |             for feat_f, feat_r in zip(feats_fake, feats_real):
26 |                 loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
27 |             loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
28 |             loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
29 | 
30 |         loss_g_sum += loss_g.item()
31 |         loss_d_sum += loss_d.item()
32 | 
33 |     loss_g_avg = loss_g_sum / len(valloader.dataset)
34 |     loss_d_avg = loss_d_sum / len(valloader.dataset)
35 | 
36 |     audio = audio[0][0].cpu().detach().numpy()
37 |     fake_audio = fake_audio[0][0].cpu().detach().numpy()
38 | 
39 |     writer.log_validation(loss_g_avg, loss_d_avg, generator, discriminator, audio, fake_audio, step)
40 | 
41 |     torch.backends.cudnn.benchmark = True
--------------------------------------------------------------------------------
/utils/writer.py:
--------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | 
3 | from .plotting import plot_waveform_to_numpy
4 | 
5 | 
6 | class MyWriter(SummaryWriter):
7 |     def __init__(self, hp, logdir):
8 |         super(MyWriter, self).__init__(logdir)
9 |         self.sample_rate = hp.audio.sampling_rate
10 |         self.is_first = True
11 | 
12 |     def log_training(self, g_loss, d_loss, step):
13 |         self.add_scalar('train.g_loss', g_loss, step)
14 |         self.add_scalar('train.d_loss', d_loss, step)
15 | 
16 |     def log_validation(self, g_loss, d_loss, generator, discriminator, target, prediction, step):
17 |         self.add_scalar('validation.g_loss', g_loss, step)
18 |         self.add_scalar('validation.d_loss', d_loss, step)
19 | 
20 |         self.add_audio('raw_audio_predicted', prediction, step, self.sample_rate)
21 |         self.add_image('waveform_predicted', plot_waveform_to_numpy(prediction), step)
22 | 
23 |         self.log_histogram(generator, step)
24 |         self.log_histogram(discriminator, step)
25 | 
26 |         if self.is_first:
27 |             self.add_audio('raw_audio_target', target, step, self.sample_rate)
28 |             self.add_image('waveform_target', plot_waveform_to_numpy(target), step)
29 |             self.is_first = False
30 | 
31 |     def log_histogram(self, model, step):
32 |         for tag, value in model.named_parameters():
33 |             self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step)
34 | 
--------------------------------------------------------------------------------
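Putting the pieces together, a full round trip (wav to mel to wav) looks roughly like this. A sketch under the default config; the paths are placeholders:

```python
# End-to-end sketch: wav -> mel (utils/stft.py) -> audio (model/generator.py).
import torch
from scipy.io.wavfile import write

from utils.stft import TacotronSTFT
from utils.utils import read_wav_np
from model.generator import Generator

stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024,
                    n_mel_channels=80, sampling_rate=22050,
                    mel_fmin=0.0, mel_fmax=8000.0)

sr, wav = read_wav_np('input.wav')  # float32 in [-1, 1]
mel = stft.mel_spectrogram(torch.from_numpy(wav).unsqueeze(0))

checkpoint = torch.load('checkpoint.pt', map_location='cpu')
model = Generator(80)
model.load_state_dict(checkpoint['model_g'])
model.eval(inference=True)  # also removes weight norm, as in hubconf.py

with torch.no_grad():
    audio = model.inference(mel)  # int16 tensor

write('reconstructed.wav', 22050, audio.cpu().numpy())
```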