├── .gitignore ├── LICENSE ├── README.md ├── generate.py ├── hifigan ├── __init__.py ├── dataset.py ├── discriminator.py ├── generator.py └── utils.py ├── hubconf.py ├── requirements.txt ├── resample.py ├── train.py └── vocoder.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # VSCode project settings 114 | .vscode 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Benjamin van Niekerk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2111.02392) 4 | [![demo](https://img.shields.io/static/v1?message=Audio%20Samples&logo=Github&labelColor=grey&color=blue&logoColor=white&label=%20&style=flat)](https://bshall.github.io/soft-vc/) 5 | [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bshall/soft-vc/blob/main/soft-vc-demo.ipynb) 6 | 7 | Training and inference scripts for the vocoder models in [A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion](https://ieeexplore.ieee.org/abstract/document/9746484). For more details see [soft-vc](https://github.com/bshall/soft-vc). Audio samples can be found [here](https://bshall.github.io/soft-vc/). Colab demo can be found [here](https://colab.research.google.com/github/bshall/soft-vc/blob/main/soft-vc-demo.ipynb). 8 | 9 |
10 | <div align="center"> 11 | <img alt="Soft-VC" src="vocoder.png"> 12 | </div> 13 | <div> 14 | <sup> 15 | Fig 1: Architecture of the voice conversion system. a) The discrete content encoder clusters audio features to produce a sequence of discrete speech units. b) The soft content encoder is trained to predict the discrete units. The acoustic model transforms the discrete/soft speech units into a target spectrogram. The vocoder converts the spectrogram into an audio waveform. 16 | </sup> 17 | </div>
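The vocoder operates on 128-band log-mel spectrograms extracted from 16 kHz audio with a hop length of 160 samples, matching `LogMelSpectrogram` in `hifigan/dataset.py`. A minimal sketch of how such an input might be prepared is shown below; it assumes the repository is on your Python path, and the audio path is hypothetical:

```python
import torchaudio
from torchaudio.functional import resample

from hifigan.dataset import LogMelSpectrogram

# Hypothetical input file; replace with your own recording.
wav, sr = torchaudio.load("path/to/audio.wav")
wav = wav.mean(0, keepdim=True)  # downmix to mono: (1, T)
wav = resample(wav, sr, 16000)   # the models are trained on 16 kHz audio

logmel = LogMelSpectrogram()
mel = logmel(wav.unsqueeze(0)).squeeze(0)  # (1, 128, T), the shape expected by hifigan.generate
```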
18 | 19 | ## Example Usage 20 | 21 | ### Programmatic Usage 22 | 23 | ```python 24 | import torch 25 | import numpy as np 26 | 27 | # Load checkpoint 28 | hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda() 29 | # Load mel-spectrogram 30 | mel = torch.from_numpy(np.load("path/to/mel")).unsqueeze(0).cuda() 31 | # Generate 32 | wav, sr = hifigan.generate(mel) 33 | ``` 34 | 35 | ### Script-Based Usage 36 | 37 | ``` 38 | usage: generate.py [-h] {soft,discrete,base} in-dir out-dir 39 | 40 | Generate audio for a directory of mel-spectrograms using HiFi-GAN. 41 | 42 | positional arguments: 43 | {soft,discrete,base} available models (HuBERT-Soft, HuBERT-Discrete, or 44 | Base). 45 | in-dir path to input directory containing the mel- 46 | spectrograms. 47 | out-dir path to output directory. 48 | 49 | optional arguments: 50 | -h, --help show this help message and exit 51 | ``` 52 | 53 | ## Training 54 | 55 | ### Step 1: Dataset Preparation 56 | 57 | Download and extract the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset. The training script expects the following tree structure for the dataset directory: 58 | 59 | ``` 60 | └───wavs 61 | ├───dev 62 | │ ├───LJ001-0001.wav 63 | │ ├───... 64 | │ └───LJ050-0278.wav 65 | └───train 66 | ├───LJ002-0332.wav 67 | ├───... 68 | └───LJ047-0007.wav 69 | ``` 70 | 71 | The `train` and `dev` directories should contain the training and validation splits, respectively. The splits used for the paper can be found [here](https://github.com/bshall/hifigan/releases/tag/v0.1). 72 | 73 | ### Step 2: Resample the Audio 74 | 75 | Resample the audio to 16kHz using the `resample.py` script: 76 | 77 | ``` 78 | usage: resample.py [-h] [--sample-rate SAMPLE_RATE] in-dir out-dir 79 | 80 | Resample an audio dataset. 81 | 82 | positional arguments: 83 | in-dir path to the dataset directory. 84 | out-dir path to the output directory. 85 | 86 | optional arguments: 87 | -h, --help show this help message and exit 88 | --sample-rate SAMPLE_RATE 89 | target sample rate (default 16kHz) 90 | ``` 91 | 92 | For example: 93 | 94 | ``` 95 | python resample.py path/to/LJSpeech-1.1/ path/to/LJSpeech-Resampled/ 96 | ``` 97 | 98 | ### Step 3: Train HiFi-GAN 99 | 100 | ``` 101 | usage: train.py [-h] [--resume RESUME] [--finetune] dataset-dir checkpoint-dir 102 | 103 | Train or finetune HiFi-GAN. 
104 | 105 | positional arguments: 106 | dataset-dir path to the preprocessed data directory 107 | checkpoint-dir path to the checkpoint directory 108 | 109 | optional arguments: 110 | -h, --help show this help message and exit 111 | --resume RESUME path to the checkpoint to resume from 112 | --finetune whether to finetune (note that a resume path must be given) 113 | ``` 114 | 115 | ## Links 116 | 117 | - [Soft-VC repo](https://github.com/bshall/soft-vc) 118 | - [Soft-VC paper](https://ieeexplore.ieee.org/abstract/document/9746484) 119 | - [HuBERT content encoders](https://github.com/bshall/hubert) 120 | - [Acoustic models](https://github.com/bshall/acoustic-model) 121 | 122 | ## Citation 123 | 124 | If you found this work helpful, please consider citing our paper: 125 | 126 | ``` 127 | @inproceedings{ 128 | soft-vc-2022, 129 | author={van Niekerk, Benjamin and Carbonneau, Marc-André and Zaïdi, Julian and Baas, Matthew and Seuté, Hugo and Kamper, Herman}, 130 | booktitle={ICASSP}, 131 | title={A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion}, 132 | year={2022} 133 | } 134 | ``` 135 | 136 | ## Acknowledgements 137 | This repo is based heavily on [https://github.com/jik876/hifi-gan](https://github.com/jik876/hifi-gan). -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import argparse 4 | import torch 5 | import torchaudio 6 | from tqdm import tqdm 7 | 8 | 9 | def generate(args): 10 | print("Loading checkpoint") 11 | model_name = f"hifigan_hubert_{args.model}" if args.model != "base" else "hifigan" 12 | hifigan = torch.hub.load("bshall/hifigan:main", model_name).cuda() 13 | 14 | print(f"Generating audio from {args.in_dir}") 15 | for path in tqdm(list(args.in_dir.rglob("*.npy"))): 16 | mel = torch.from_numpy(np.load(path)) 17 | mel = mel.unsqueeze(0).cuda() 18 | 19 | wav, sr = hifigan.generate(mel) 20 | wav = wav.squeeze(0).cpu() 21 | 22 | out_path = args.out_dir / path.relative_to(args.in_dir) 23 | out_path.parent.mkdir(exist_ok=True, parents=True) 24 | torchaudio.save(out_path.with_suffix(".wav"), wav, sr) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser( 29 | description="Generate audio for a directory of mel-spectrograms using HiFi-GAN." 
30 | ) 31 | parser.add_argument( 32 | "model", 33 | help="available models (HuBERT-Soft, HuBERT-Discrete, or Base).", 34 | choices=["soft", "discrete", "base"], 35 | ) 36 | parser.add_argument( 37 | "in_dir", 38 | metavar="in-dir", 39 | help="path to input directory containing the mel-spectrograms.", 40 | type=Path, 41 | ) 42 | parser.add_argument( 43 | "out_dir", 44 | metavar="out-dir", 45 | help="path to output directory.", 46 | type=Path, 47 | ) 48 | args = parser.parse_args() 49 | 50 | generate(args) 51 | -------------------------------------------------------------------------------- /hifigan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshall/hifigan/49c16c935cea45ba4073811de974d9ae277c3bc9/hifigan/__init__.py -------------------------------------------------------------------------------- /hifigan/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import math 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from torch.utils.data import Dataset 9 | 10 | import torchaudio 11 | import torchaudio.transforms as transforms 12 | 13 | 14 | class LogMelSpectrogram(torch.nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | self.melspctrogram = transforms.MelSpectrogram( 18 | sample_rate=16000, 19 | n_fft=1024, 20 | win_length=1024, 21 | hop_length=160, 22 | center=False, 23 | power=1.0, 24 | norm="slaney", 25 | onesided=True, 26 | n_mels=128, 27 | mel_scale="slaney", 28 | ) 29 | 30 | def forward(self, wav): 31 | wav = F.pad(wav, ((1024 - 160) // 2, (1024 - 160) // 2), "reflect") 32 | mel = self.melspctrogram(wav) 33 | logmel = torch.log(torch.clamp(mel, min=1e-5)) 34 | return logmel 35 | 36 | 37 | class MelDataset(Dataset): 38 | def __init__( 39 | self, 40 | root: Path, 41 | segment_length: int, 42 | sample_rate: int, 43 | hop_length: int, 44 | train: bool = True, 45 | finetune: bool = False, 46 | ): 47 | self.wavs_dir = root / "wavs" 48 | self.mels_dir = root / "mels" 49 | self.data_dir = self.wavs_dir if not finetune else self.mels_dir 50 | 51 | self.segment_length = segment_length 52 | self.sample_rate = sample_rate 53 | self.hop_length = hop_length 54 | self.train = train 55 | self.finetune = finetune 56 | 57 | suffix = ".wav" if not finetune else ".npy" 58 | pattern = f"train/**/*{suffix}" if train else f"dev/**/*{suffix}" 59 | 60 | self.metadata = [ 61 | path.relative_to(self.data_dir).with_suffix("") 62 | for path in self.data_dir.rglob(pattern) 63 | ] 64 | 65 | self.logmel = LogMelSpectrogram() 66 | 67 | def __len__(self): 68 | return len(self.metadata) 69 | 70 | def __getitem__(self, index): 71 | path = self.metadata[index] 72 | wav_path = self.wavs_dir / path 73 | 74 | info = torchaudio.info(wav_path.with_suffix(".wav")) 75 | if info.sample_rate != self.sample_rate: 76 | raise ValueError( 77 | f"Sample rate {info.sample_rate} doesn't match target of {self.sample_rate}" 78 | ) 79 | 80 | if self.finetune: 81 | mel_path = self.mels_dir / path 82 | src_logmel = torch.from_numpy(np.load(mel_path.with_suffix(".npy"))) 83 | src_logmel = src_logmel.unsqueeze(0) 84 | 85 | mel_frames_per_segment = math.ceil(self.segment_length / self.hop_length) 86 | mel_diff = src_logmel.size(-1) - mel_frames_per_segment if self.train else 0 87 | mel_offset = random.randint(0, max(mel_diff, 0)) 88 | 89 | frame_offset = self.hop_length * mel_offset 90 | else: 91 | frame_diff = 
info.num_frames - self.segment_length 92 | frame_offset = random.randint(0, max(frame_diff, 0)) 93 | 94 | wav, _ = torchaudio.load( 95 | filepath=wav_path.with_suffix(".wav"), 96 | frame_offset=frame_offset if self.train else 0, 97 | num_frames=self.segment_length if self.train else -1, 98 | ) 99 | 100 | if wav.size(-1) < self.segment_length: 101 | wav = F.pad(wav, (0, self.segment_length - wav.size(-1))) 102 | 103 | if not self.finetune and self.train: 104 | gain = random.random() * (0.99 - 0.4) + 0.4 105 | flip = -1 if random.random() > 0.5 else 1 106 | wav = flip * gain * wav / max(wav.abs().max(), 1e-5) 107 | 108 | tgt_logmel = self.logmel(wav.unsqueeze(0)).squeeze(0) 109 | 110 | if self.finetune: 111 | if self.train: 112 | src_logmel = src_logmel[ 113 | :, :, mel_offset : mel_offset + mel_frames_per_segment 114 | ] 115 | 116 | if src_logmel.size(-1) < mel_frames_per_segment: 117 | src_logmel = F.pad( 118 | src_logmel, 119 | (0, mel_frames_per_segment - src_logmel.size(-1)), 120 | "constant", 121 | src_logmel.min(), 122 | ) 123 | else: 124 | src_logmel = tgt_logmel.clone() 125 | 126 | return wav, src_logmel, tgt_logmel 127 | -------------------------------------------------------------------------------- /hifigan/discriminator.py: -------------------------------------------------------------------------------- 1 | # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Tuple, List 6 | 7 | from hifigan.utils import get_padding 8 | 9 | 10 | LRELU_SLOPE = 0.1 11 | 12 | 13 | class PeriodDiscriminator(torch.nn.Module): 14 | """HiFiGAN Period Discriminator""" 15 | 16 | def __init__( 17 | self, 18 | period: int, 19 | kernel_size: int = 5, 20 | stride: int = 3, 21 | use_spectral_norm: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.period = period 25 | norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm 26 | self.convs = nn.ModuleList( 27 | [ 28 | norm_f( 29 | nn.Conv2d( 30 | 1, 31 | 32, 32 | (kernel_size, 1), 33 | (stride, 1), 34 | padding=(get_padding(5, 1), 0), 35 | ) 36 | ), 37 | norm_f( 38 | nn.Conv2d( 39 | 32, 40 | 128, 41 | (kernel_size, 1), 42 | (stride, 1), 43 | padding=(get_padding(5, 1), 0), 44 | ) 45 | ), 46 | norm_f( 47 | nn.Conv2d( 48 | 128, 49 | 512, 50 | (kernel_size, 1), 51 | (stride, 1), 52 | padding=(get_padding(5, 1), 0), 53 | ) 54 | ), 55 | norm_f( 56 | nn.Conv2d( 57 | 512, 58 | 1024, 59 | (kernel_size, 1), 60 | (stride, 1), 61 | padding=(get_padding(5, 1), 0), 62 | ) 63 | ), 64 | norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 65 | ] 66 | ) 67 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 68 | 69 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 70 | """ 71 | Args: 72 | x (Tensor): input waveform. 73 | Returns: 74 | [Tensor]: discriminator scores per sample in the batch. 75 | [List[Tensor]]: list of features from each convolutional layer. 
76 | """ 77 | feat = [] 78 | 79 | # 1d to 2d 80 | b, c, t = x.shape 81 | if t % self.period != 0: # pad first 82 | n_pad = self.period - (t % self.period) 83 | x = F.pad(x, (0, n_pad), "reflect") 84 | t = t + n_pad 85 | x = x.view(b, c, t // self.period, self.period) 86 | 87 | for l in self.convs: 88 | x = l(x) 89 | x = F.leaky_relu(x, LRELU_SLOPE) 90 | feat.append(x) 91 | x = self.conv_post(x) 92 | feat.append(x) 93 | x = torch.flatten(x, 1, -1) 94 | 95 | return x, feat 96 | 97 | 98 | class MultiPeriodDiscriminator(torch.nn.Module): 99 | """HiFiGAN Multi-Period Discriminator (MPD)""" 100 | 101 | def __init__(self): 102 | super().__init__() 103 | self.discriminators = nn.ModuleList( 104 | [ 105 | PeriodDiscriminator(2), 106 | PeriodDiscriminator(3), 107 | PeriodDiscriminator(5), 108 | PeriodDiscriminator(7), 109 | PeriodDiscriminator(11), 110 | ] 111 | ) 112 | 113 | def forward( 114 | self, x: torch.Tensor 115 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: 116 | """ 117 | Args: 118 | x (Tensor): input waveform. 119 | Returns: 120 | [List[Tensor]]: list of scores from each discriminator. 121 | [List[List[Tensor]]]: list of features from each discriminator's convolutional layers. 122 | """ 123 | scores = [] 124 | feats = [] 125 | for _, d in enumerate(self.discriminators): 126 | score, feat = d(x) 127 | scores.append(score) 128 | feats.append(feat) 129 | return scores, feats 130 | 131 | 132 | class ScaleDiscriminator(torch.nn.Module): 133 | """HiFiGAN Scale Discriminator.""" 134 | 135 | def __init__(self, use_spectral_norm: bool = False) -> None: 136 | super().__init__() 137 | norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm 138 | self.convs = nn.ModuleList( 139 | [ 140 | norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), 141 | norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), 142 | norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), 143 | norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), 144 | norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 145 | norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 146 | norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), 147 | ] 148 | ) 149 | self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) 150 | 151 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 152 | """ 153 | Args: 154 | x (Tensor): input waveform. 155 | Returns: 156 | Tensor: discriminator scores. 157 | List[Tensor]: list of features from the convolutional layers. 158 | """ 159 | feat = [] 160 | for l in self.convs: 161 | x = l(x) 162 | x = F.leaky_relu(x, LRELU_SLOPE) 163 | feat.append(x) 164 | x = self.conv_post(x) 165 | feat.append(x) 166 | x = torch.flatten(x, 1, -1) 167 | return x, feat 168 | 169 | 170 | class MultiScaleDiscriminator(torch.nn.Module): 171 | """HiFiGAN Multi-Scale Discriminator.""" 172 | 173 | def __init__(self): 174 | super().__init__() 175 | self.discriminators = nn.ModuleList( 176 | [ 177 | ScaleDiscriminator(use_spectral_norm=True), 178 | ScaleDiscriminator(), 179 | ScaleDiscriminator(), 180 | ] 181 | ) 182 | self.meanpools = nn.ModuleList( 183 | [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)] 184 | ) 185 | 186 | def forward( 187 | self, x: torch.Tensor 188 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: 189 | """ 190 | Args: 191 | x (Tensor): input waveform. 192 | Returns: 193 | List[Tensor]: discriminator scores. 194 | List[List[Tensor]]: list of features from each discriminator's convolutional layers. 
195 | """ 196 | scores = [] 197 | feats = [] 198 | for i, d in enumerate(self.discriminators): 199 | if i != 0: 200 | x = self.meanpools[i - 1](x) 201 | score, feat = d(x) 202 | scores.append(score) 203 | feats.append(feat) 204 | return scores, feats 205 | 206 | 207 | class HifiganDiscriminator(nn.Module): 208 | """HiFiGAN discriminator""" 209 | 210 | def __init__(self): 211 | super().__init__() 212 | self.mpd = MultiPeriodDiscriminator() 213 | self.msd = MultiScaleDiscriminator() 214 | 215 | def forward( 216 | self, x: torch.Tensor 217 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: 218 | """ 219 | Args: 220 | x (Tensor): input waveform. 221 | Returns: 222 | List[Tensor]: discriminator scores. 223 | List[List[Tensor]]: list of features from from each discriminator's convolutional layers. 224 | """ 225 | scores, feats = self.mpd(x) 226 | scores_, feats_ = self.msd(x) 227 | return scores + scores_, feats + feats_ 228 | 229 | 230 | def feature_loss( 231 | features_real: List[List[torch.Tensor]], features_generate: List[List[torch.Tensor]] 232 | ) -> float: 233 | loss = 0 234 | for r, g in zip(features_real, features_generate): 235 | for rl, gl in zip(r, g): 236 | loss += torch.mean(torch.abs(rl - gl)) 237 | return loss * 2 238 | 239 | 240 | def discriminator_loss(real, generated): 241 | loss = 0 242 | real_losses = [] 243 | generated_losses = [] 244 | for r, g in zip(real, generated): 245 | r_loss = torch.mean((1 - r) ** 2) 246 | g_loss = torch.mean(g ** 2) 247 | loss += r_loss + g_loss 248 | real_losses.append(r_loss.item()) 249 | generated_losses.append(g_loss.item()) 250 | 251 | return loss, real_losses, generated_losses 252 | 253 | 254 | def generator_loss(discriminator_outputs): 255 | loss = 0 256 | generator_losses = [] 257 | for x in discriminator_outputs: 258 | l = torch.mean((1 - x) ** 2) 259 | generator_losses.append(l) 260 | loss += l 261 | 262 | return loss, generator_losses 263 | -------------------------------------------------------------------------------- /hifigan/generator.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/jik876/hifi-gan/blob/master/models.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils import remove_weight_norm, weight_norm 6 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 7 | from typing import Tuple 8 | 9 | from hifigan.utils import get_padding 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class HifiganGenerator(torch.nn.Module): 15 | def __init__( 16 | self, 17 | in_channels: int = 128, 18 | resblock_dilation_sizes: Tuple[Tuple[int, ...], ...] = ( 19 | (1, 3, 5), 20 | (1, 3, 5), 21 | (1, 3, 5), 22 | ), 23 | resblock_kernel_sizes: Tuple[int, ...] = (3, 7, 11), 24 | upsample_kernel_sizes: Tuple[int, ...] = (20, 8, 4, 4), 25 | upsample_initial_channel: int = 512, 26 | upsample_factors: int = (10, 4, 2, 2), 27 | inference_padding: int = 5, 28 | sample_rate: int = 16000, 29 | ): 30 | r"""HiFiGAN Generator 31 | Args: 32 | in_channels (int): number of input channels. 33 | resblock_dilation_sizes (Tuple[Tuple[int, ...], ...]): list of dilation values in each layer of a `ResBlock`. 34 | resblock_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each `ResBlock`. 35 | upsample_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each transposed convolution. 36 | upsample_initial_channel (int): number of channels for the first upsampling layer. 
This is divided by 2 37 | for each consecutive upsampling layer. 38 | upsample_factors (Tuple[int, ...]): upsampling factors (stride) for each upsampling layer. 39 | inference_padding (int): constant padding applied to the input at inference time. 40 | sample_rate (int): sample rate of the generated audio. 41 | """ 42 | super().__init__() 43 | self.inference_padding = inference_padding 44 | self.num_kernels = len(resblock_kernel_sizes) 45 | self.num_upsamples = len(upsample_factors) 46 | self.sample_rate = sample_rate 47 | 48 | # initial upsampling layers 49 | self.conv_pre = weight_norm( 50 | nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) 51 | ) 52 | 53 | # upsampling layers 54 | self.ups = nn.ModuleList() 55 | for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): 56 | self.ups.append( 57 | weight_norm( 58 | nn.ConvTranspose1d( 59 | upsample_initial_channel // (2**i), 60 | upsample_initial_channel // (2 ** (i + 1)), 61 | k, 62 | u, 63 | padding=(k - u) // 2, 64 | ) 65 | ) 66 | ) 67 | 68 | # MRF blocks 69 | self.resblocks = nn.ModuleList() 70 | for i in range(len(self.ups)): 71 | ch = upsample_initial_channel // (2 ** (i + 1)) 72 | for _, (k, d) in enumerate( 73 | zip(resblock_kernel_sizes, resblock_dilation_sizes) 74 | ): 75 | self.resblocks.append(ResBlock(ch, k, d)) 76 | 77 | # post convolution layer 78 | self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3)) 79 | 80 | def forward(self, x: torch.Tensor) -> torch.Tensor: 81 | o = self.conv_pre(x) 82 | for i in range(self.num_upsamples): 83 | o = F.leaky_relu(o, LRELU_SLOPE) 84 | o = self.ups[i](o) 85 | z_sum = None 86 | for j in range(self.num_kernels): 87 | if z_sum is None: 88 | z_sum = self.resblocks[i * self.num_kernels + j](o) 89 | else: 90 | z_sum += self.resblocks[i * self.num_kernels + j](o) 91 | o = z_sum / self.num_kernels 92 | o = F.leaky_relu(o) 93 | o = self.conv_post(o) 94 | o = torch.tanh(o) 95 | return o 96 | 97 | def remove_weight_norm(self): 98 | for layer in self.ups: 99 | remove_weight_norm(layer) 100 | 101 | for layer in self.resblocks: 102 | layer.remove_weight_norm() 103 | 104 | remove_weight_norm(self.conv_pre) 105 | remove_weight_norm(self.conv_post) 106 | 107 | 108 | class ResBlock(torch.nn.Module): 109 | def __init__( 110 | self, channels: int, kernel_size: int = 3, dilation: Tuple[int, ...] 
= (1, 3, 5) 111 | ) -> None: 112 | super().__init__() 113 | self.convs1 = nn.ModuleList( 114 | [ 115 | weight_norm( 116 | nn.Conv1d( 117 | channels, 118 | channels, 119 | kernel_size, 120 | 1, 121 | dilation=dilation[0], 122 | padding=get_padding(kernel_size, dilation[0]), 123 | ) 124 | ), 125 | weight_norm( 126 | nn.Conv1d( 127 | channels, 128 | channels, 129 | kernel_size, 130 | 1, 131 | dilation=dilation[1], 132 | padding=get_padding(kernel_size, dilation[1]), 133 | ) 134 | ), 135 | weight_norm( 136 | nn.Conv1d( 137 | channels, 138 | channels, 139 | kernel_size, 140 | 1, 141 | dilation=dilation[2], 142 | padding=get_padding(kernel_size, dilation[2]), 143 | ) 144 | ), 145 | ] 146 | ) 147 | 148 | self.convs2 = nn.ModuleList( 149 | [ 150 | weight_norm( 151 | nn.Conv1d( 152 | channels, 153 | channels, 154 | kernel_size, 155 | 1, 156 | dilation=1, 157 | padding=get_padding(kernel_size, 1), 158 | ) 159 | ), 160 | weight_norm( 161 | nn.Conv1d( 162 | channels, 163 | channels, 164 | kernel_size, 165 | 1, 166 | dilation=1, 167 | padding=get_padding(kernel_size, 1), 168 | ) 169 | ), 170 | weight_norm( 171 | nn.Conv1d( 172 | channels, 173 | channels, 174 | kernel_size, 175 | 1, 176 | dilation=1, 177 | padding=get_padding(kernel_size, 1), 178 | ) 179 | ), 180 | ] 181 | ) 182 | 183 | def forward(self, x: torch.Tensor) -> torch.Tensor: 184 | for c1, c2 in zip(self.convs1, self.convs2): 185 | xt = F.leaky_relu(x, LRELU_SLOPE) 186 | xt = c1(xt) 187 | xt = F.leaky_relu(xt, LRELU_SLOPE) 188 | xt = c2(xt) 189 | x = xt + x 190 | return x 191 | 192 | def remove_weight_norm(self): 193 | for l in self.convs1: 194 | remove_weight_norm(l) 195 | for l in self.convs2: 196 | remove_weight_norm(l) 197 | -------------------------------------------------------------------------------- /hifigan/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib 3 | 4 | matplotlib.use("Agg") 5 | import matplotlib.pylab as plt 6 | 7 | 8 | def get_padding(k, d): 9 | return int((k * d - d) / 2) 10 | 11 | 12 | def plot_spectrogram(spectrogram): 13 | fig, ax = plt.subplots(figsize=(10, 2)) 14 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 15 | plt.colorbar(im, ax=ax) 16 | 17 | fig.canvas.draw() 18 | plt.close() 19 | 20 | return fig 21 | 22 | 23 | def save_checkpoint( 24 | checkpoint_dir, 25 | generator, 26 | discriminator, 27 | optimizer_generator, 28 | optimizer_discriminator, 29 | scheduler_generator, 30 | scheduler_discriminator, 31 | step, 32 | loss, 33 | best, 34 | logger, 35 | ): 36 | state = { 37 | "generator": { 38 | "model": generator.state_dict(), 39 | "optimizer": optimizer_generator.state_dict(), 40 | "scheduler": scheduler_generator.state_dict(), 41 | }, 42 | "discriminator": { 43 | "model": discriminator.state_dict(), 44 | "optimizer": optimizer_discriminator.state_dict(), 45 | "scheduler": scheduler_discriminator.state_dict(), 46 | }, 47 | "step": step, 48 | "loss": loss, 49 | } 50 | checkpoint_dir.mkdir(exist_ok=True, parents=True) 51 | checkpoint_path = checkpoint_dir / f"model-{step}.pt" 52 | torch.save(state, checkpoint_path) 53 | if best: 54 | best_path = checkpoint_dir / "model-best.pt" 55 | torch.save(state, best_path) 56 | logger.info(f"Saved checkpoint: {checkpoint_path.stem}") 57 | 58 | 59 | def load_checkpoint( 60 | load_path, 61 | generator, 62 | discriminator, 63 | optimizer_generator, 64 | optimizer_discriminator, 65 | scheduler_generator, 66 | scheduler_discriminator, 67 | rank, 68 | logger, 69 | 
finetune=False, 70 | ): 71 | logger.info(f"Loading checkpoint from {load_path}") 72 | checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"}) 73 | generator.load_state_dict(checkpoint["generator"]["model"]) 74 | discriminator.load_state_dict(checkpoint["discriminator"]["model"]) 75 | if not finetune: 76 | optimizer_generator.load_state_dict(checkpoint["generator"]["optimizer"]) 77 | scheduler_generator.load_state_dict(checkpoint["generator"]["scheduler"]) 78 | optimizer_discriminator.load_state_dict( 79 | checkpoint["discriminator"]["optimizer"] 80 | ) 81 | scheduler_discriminator.load_state_dict( 82 | checkpoint["discriminator"]["scheduler"] 83 | ) 84 | return checkpoint["step"], checkpoint["loss"] 85 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ["torch", "torchaudio"] 2 | 3 | URLS = { 4 | "hifigan": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-67926ec6.pt", 5 | "hifigan-hubert-discrete": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-hubert-discrete-bbad3043.pt", 6 | "hifigan-hubert-soft": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-hubert-soft-65f03469.pt", 7 | } 8 | 9 | import torch 10 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 11 | 12 | from hifigan.generator import HifiganGenerator 13 | 14 | 15 | def _hifigan( 16 | name: str, 17 | pretrained: bool = True, 18 | progress: bool = True, 19 | map_location=None, 20 | ) -> HifiganGenerator: 21 | hifigan = HifiganGenerator() 22 | if pretrained: 23 | checkpoint = torch.hub.load_state_dict_from_url( 24 | URLS[name], map_location=map_location, progress=progress 25 | ) 26 | consume_prefix_in_state_dict_if_present(checkpoint, "module.") 27 | hifigan.load_state_dict(checkpoint) 28 | hifigan.eval() 29 | hifigan.remove_weight_norm() 30 | return hifigan 31 | 32 | 33 | def hifigan( 34 | pretrained: bool = True, 35 | progress: bool = True, 36 | map_location=None, 37 | ) -> HifiganGenerator: 38 | """HiFiGAN Vocoder from from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 39 | Args: 40 | pretrained (bool): load pretrained weights into the model 41 | progress (bool): show progress bar when downloading model 42 | map_location: a function or a dict specifying how to remap storage locations (see torch.load) 43 | """ 44 | return _hifigan("hifigan", pretrained, progress, map_location) 45 | 46 | 47 | def hifigan_hubert_soft( 48 | pretrained: bool = True, 49 | progress: bool = True, 50 | map_location=None, 51 | ) -> HifiganGenerator: 52 | """HiFiGAN Vocoder from from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 53 | Finetuned on spectrograms generated from the soft acoustic model. 54 | Args: 55 | pretrained (bool): load pretrained weights into the model 56 | progress (bool): show progress bar when downloading model 57 | map_location: a function or a dict specifying how to remap storage locations (see torch.load) 58 | """ 59 | return _hifigan( 60 | "hifigan-hubert-soft", pretrained, progress, map_location=map_location 61 | ) 62 | 63 | 64 | def hifigan_hubert_discrete( 65 | pretrained: bool = True, 66 | progress: bool = True, 67 | map_location=None, 68 | ) -> HifiganGenerator: 69 | """HiFiGAN Vocoder from from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 
70 | Finetuned on spectrograms generated from the discrete acoustic model. 71 | Args: 72 | pretrained (bool): load pretrained weights into the model 73 | progress (bool): show progress bar when downloading model 74 | map_location: a function or a dict specifying how to remap storage locations (see torch.load) 75 | """ 76 | return _hifigan( 77 | "hifigan-hubert-discrete", pretrained, progress, map_location=map_location 78 | ) 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard==2.7.0 2 | torch==1.9.1 3 | torchaudio==0.9.1 4 | tqdm==4.62.3 -------------------------------------------------------------------------------- /resample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from concurrent.futures import ProcessPoolExecutor 4 | from multiprocessing import cpu_count 5 | 6 | import torchaudio 7 | from torchaudio.functional import resample 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | def process_wav(in_path, out_path, sample_rate): 13 | wav, sr = torchaudio.load(in_path) 14 | wav = resample(wav, sr, sample_rate) 15 | torchaudio.save(out_path, wav, sample_rate) 16 | return out_path, wav.size(-1) / sample_rate 17 | 18 | 19 | def preprocess_dataset(args): 20 | args.out_dir.mkdir(parents=True, exist_ok=True) 21 | 22 | futures = [] 23 | executor = ProcessPoolExecutor(max_workers=cpu_count()) 24 | print(f"Resampling audio in {args.in_dir}") 25 | for in_path in args.in_dir.rglob("*.wav"): 26 | relative_path = in_path.relative_to(args.in_dir) 27 | out_path = args.out_dir / relative_path 28 | out_path.parent.mkdir(parents=True, exist_ok=True) 29 | futures.append( 30 | executor.submit(process_wav, in_path, out_path, args.sample_rate) 31 | ) 32 | 33 | results = [future.result() for future in tqdm(futures)] 34 | 35 | lengths = {path.stem: length for path, length in results} 36 | seconds = sum(lengths.values()) 37 | hours = seconds / 3600 38 | print(f"Wrote {len(lengths)} utterances ({hours:.2f} hours)") 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser(description="Resample an audio dataset.") 43 | parser.add_argument( 44 | "in_dir", metavar="in-dir", help="path to the dataset directory.", type=Path 45 | ) 46 | parser.add_argument( 47 | "out_dir", metavar="out-dir", help="path to the output directory.", type=Path 48 | ) 49 | parser.add_argument( 50 | "--sample-rate", 51 | help="target sample rate (default 16kHz)", 52 | type=int, 53 | default=16000, 54 | ) 55 | args = parser.parse_args() 56 | preprocess_dataset(args) 57 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.distributed as dist 11 | from torch.utils.data.distributed import DistributedSampler 12 | import torch.multiprocessing as mp 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | from hifigan.generator import HifiganGenerator 16 | from hifigan.discriminator import ( 17 | HifiganDiscriminator, 18 | feature_loss, 19 | discriminator_loss, 20 | generator_loss, 21 | ) 
22 | from hifigan.dataset import MelDataset, LogMelSpectrogram 23 | from hifigan.utils import load_checkpoint, save_checkpoint, plot_spectrogram 24 | 25 | 26 | logging.basicConfig(level=logging.DEBUG) 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | BATCH_SIZE = 8 31 | SEGMENT_LENGTH = 8320 32 | HOP_LENGTH = 160 33 | SAMPLE_RATE = 16000 34 | BASE_LEARNING_RATE = 2e-4 35 | FINETUNE_LEARNING_RATE = 1e-4 36 | BETAS = (0.8, 0.99) 37 | LEARNING_RATE_DECAY = 0.999 38 | WEIGHT_DECAY = 1e-5 39 | EPOCHS = 3100 40 | LOG_INTERVAL = 5 41 | VALIDATION_INTERVAL = 1000 42 | NUM_GENERATED_EXAMPLES = 10 43 | CHECKPOINT_INTERVAL = 5000 44 | 45 | 46 | def train_model(rank, world_size, args): 47 | dist.init_process_group( 48 | "nccl", 49 | rank=rank, 50 | world_size=world_size, 51 | init_method="tcp://localhost:54321", 52 | ) 53 | 54 | log_dir = args.checkpoint_dir / "logs" 55 | log_dir.mkdir(exist_ok=True, parents=True) 56 | 57 | if rank == 0: 58 | logger.setLevel(logging.DEBUG) 59 | handler = logging.FileHandler(log_dir / f"{args.checkpoint_dir.stem}.log") 60 | handler.setLevel(logging.DEBUG) 61 | formatter = logging.Formatter( 62 | "%(asctime)s [%(levelname)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S" 63 | ) 64 | handler.setFormatter(formatter) 65 | logger.addHandler(handler) 66 | else: 67 | logger.setLevel(logging.ERROR) 68 | 69 | writer = SummaryWriter(log_dir) if rank == 0 else None 70 | 71 | generator = HifiganGenerator().to(rank) 72 | discriminator = HifiganDiscriminator().to(rank) 73 | 74 | generator = DDP(generator, device_ids=[rank]) 75 | discriminator = DDP(discriminator, device_ids=[rank]) 76 | 77 | optimizer_generator = optim.AdamW( 78 | generator.parameters(), 79 | lr=BASE_LEARNING_RATE if not args.finetune else FINETUNE_LEARNING_RATE, 80 | betas=BETAS, 81 | weight_decay=WEIGHT_DECAY, 82 | ) 83 | optimizer_discriminator = optim.AdamW( 84 | discriminator.parameters(), 85 | lr=BASE_LEARNING_RATE if not args.finetune else FINETUNE_LEARNING_RATE, 86 | betas=BETAS, 87 | weight_decay=WEIGHT_DECAY, 88 | ) 89 | 90 | scheduler_generator = optim.lr_scheduler.ExponentialLR( 91 | optimizer_generator, gamma=LEARNING_RATE_DECAY 92 | ) 93 | scheduler_discriminator = optim.lr_scheduler.ExponentialLR( 94 | optimizer_discriminator, gamma=LEARNING_RATE_DECAY 95 | ) 96 | 97 | train_dataset = MelDataset( 98 | root=args.dataset_dir, 99 | segment_length=SEGMENT_LENGTH, 100 | sample_rate=SAMPLE_RATE, 101 | hop_length=HOP_LENGTH, 102 | train=True, 103 | finetune=args.finetune, 104 | ) 105 | train_sampler = DistributedSampler(train_dataset, drop_last=True) 106 | train_loader = DataLoader( 107 | train_dataset, 108 | batch_size=BATCH_SIZE, 109 | sampler=train_sampler, 110 | num_workers=8, 111 | pin_memory=True, 112 | shuffle=False, 113 | drop_last=True, 114 | ) 115 | 116 | validation_dataset = MelDataset( 117 | root=args.dataset_dir, 118 | segment_length=SEGMENT_LENGTH, 119 | sample_rate=SAMPLE_RATE, 120 | hop_length=HOP_LENGTH, 121 | train=False, 122 | finetune=args.finetune, 123 | ) 124 | validation_loader = DataLoader( 125 | validation_dataset, 126 | batch_size=1, 127 | shuffle=False, 128 | num_workers=8, 129 | pin_memory=True, 130 | ) 131 | 132 | melspectrogram = LogMelSpectrogram().to(rank) 133 | 134 | if args.resume is not None: 135 | global_step, best_loss = load_checkpoint( 136 | load_path=args.resume, 137 | generator=generator, 138 | discriminator=discriminator, 139 | optimizer_generator=optimizer_generator, 140 | optimizer_discriminator=optimizer_discriminator, 141 | 
scheduler_generator=scheduler_generator, 142 | scheduler_discriminator=scheduler_discriminator, 143 | rank=rank, 144 | logger=logger, 145 | finetune=args.finetune, 146 | ) 147 | else: 148 | global_step, best_loss = 0, float("inf") 149 | 150 | if args.finetune: 151 | global_step, best_loss = 0, float("inf") 152 | 153 | n_epochs = EPOCHS 154 | start_epoch = global_step // len(train_loader) + 1 155 | 156 | logger.info("**" * 40) 157 | logger.info(f"batch size: {BATCH_SIZE}") 158 | logger.info(f"iterations per epoch: {len(train_loader)}") 159 | logger.info(f"total of epochs: {n_epochs}") 160 | logger.info(f"started at epoch: {start_epoch}") 161 | logger.info("**" * 40 + "\n") 162 | 163 | for epoch in range(start_epoch, n_epochs + 1): 164 | train_sampler.set_epoch(epoch) 165 | 166 | generator.train() 167 | discriminator.train() 168 | average_loss_mel = average_loss_discriminator = average_loss_generator = 0 169 | for i, (wavs, mels, tgts) in enumerate(train_loader, 1): 170 | wavs, mels, tgts = wavs.to(rank), mels.to(rank), tgts.to(rank) 171 | 172 | # Discriminator 173 | optimizer_discriminator.zero_grad() 174 | 175 | wavs_ = generator(mels.squeeze(1)) 176 | mels_ = melspectrogram(wavs_) 177 | 178 | scores, _ = discriminator(wavs) 179 | scores_, _ = discriminator(wavs_.detach()) 180 | 181 | loss_discriminator, _, _ = discriminator_loss(scores, scores_) 182 | 183 | loss_discriminator.backward() 184 | optimizer_discriminator.step() 185 | 186 | # Generator 187 | optimizer_generator.zero_grad() 188 | 189 | scores, features = discriminator(wavs) 190 | scores_, features_ = discriminator(wavs_) 191 | 192 | loss_mel = F.l1_loss(mels_, tgts) 193 | loss_features = feature_loss(features, features_) 194 | loss_generator_adversarial, _ = generator_loss(scores_) 195 | loss_generator = 45 * loss_mel + loss_features + loss_generator_adversarial 196 | 197 | loss_generator.backward() 198 | optimizer_generator.step() 199 | 200 | global_step += 1 201 | 202 | average_loss_mel += (loss_mel.item() - average_loss_mel) / i 203 | average_loss_discriminator += ( 204 | loss_discriminator.item() - average_loss_discriminator 205 | ) / i 206 | average_loss_generator += ( 207 | loss_generator.item() - average_loss_generator 208 | ) / i 209 | 210 | if rank == 0: 211 | if global_step % LOG_INTERVAL == 0: 212 | writer.add_scalar( 213 | "train/loss_mel", 214 | loss_mel.item(), 215 | global_step, 216 | ) 217 | writer.add_scalar( 218 | "train/loss_generator", 219 | loss_generator.item(), 220 | global_step, 221 | ) 222 | writer.add_scalar( 223 | "train/loss_discriminator", 224 | loss_discriminator.item(), 225 | global_step, 226 | ) 227 | 228 | if global_step % VALIDATION_INTERVAL == 0: 229 | generator.eval() 230 | 231 | average_validation_loss = 0 232 | for j, (wavs, mels, tgts) in enumerate(validation_loader, 1): 233 | wavs, mels, tgts = wavs.to(rank), mels.to(rank), tgts.to(rank) 234 | 235 | with torch.no_grad(): 236 | wavs_ = generator(mels.squeeze(1)) 237 | mels_ = melspectrogram(wavs_) 238 | 239 | length = min(mels_.size(-1), tgts.size(-1)) 240 | 241 | loss_mel = F.l1_loss(mels_[..., :length], tgts[..., :length]) 242 | 243 | average_validation_loss += ( 244 | loss_mel.item() - average_validation_loss 245 | ) / j 246 | 247 | if rank == 0: 248 | if j <= NUM_GENERATED_EXAMPLES: 249 | writer.add_audio( 250 | f"generated/wav_{j}", 251 | wavs_.squeeze(0), 252 | global_step, 253 | sample_rate=16000, 254 | ) 255 | writer.add_figure( 256 | f"generated/mel_{j}", 257 | plot_spectrogram(mels_.squeeze().cpu().numpy()), 258 | global_step, 259 
| ) 260 | 261 | generator.train() 262 | discriminator.train() 263 | 264 | if rank == 0: 265 | writer.add_scalar( 266 | "validation/mel_loss", average_validation_loss, global_step 267 | ) 268 | logger.info( 269 | f"valid -- epoch: {epoch}, mel loss: {average_validation_loss:.4f}" 270 | ) 271 | 272 | new_best = best_loss > average_validation_loss 273 | if new_best or global_step % CHECKPOINT_INTERVAL == 0: 274 | if new_best: 275 | logger.info("-------- new best model found!") 276 | best_loss = average_validation_loss 277 | 278 | if rank == 0: 279 | save_checkpoint( 280 | checkpoint_dir=args.checkpoint_dir, 281 | generator=generator, 282 | discriminator=discriminator, 283 | optimizer_generator=optimizer_generator, 284 | optimizer_discriminator=optimizer_discriminator, 285 | scheduler_generator=scheduler_generator, 286 | scheduler_discriminator=scheduler_discriminator, 287 | step=global_step, 288 | loss=average_validation_loss, 289 | best=new_best, 290 | logger=logger, 291 | ) 292 | 293 | scheduler_discriminator.step() 294 | scheduler_generator.step() 295 | 296 | logger.info( 297 | f"train -- epoch: {epoch}, mel loss: {average_loss_mel:.4f}, generator loss: {average_loss_generator:.4f}, discriminator loss: {average_loss_discriminator:.4f}" 298 | ) 299 | 300 | dist.destroy_process_group() 301 | 302 | 303 | if __name__ == "__main__": 304 | parser = argparse.ArgumentParser(description="Train or finetune HiFi-GAN.") 305 | parser.add_argument( 306 | "dataset_dir", 307 | metavar="dataset-dir", 308 | help="path to the preprocessed data directory", 309 | type=Path, 310 | ) 311 | parser.add_argument( 312 | "checkpoint_dir", 313 | metavar="checkpoint-dir", 314 | help="path to the checkpoint directory", 315 | type=Path, 316 | ) 317 | parser.add_argument( 318 | "--resume", 319 | help="path to the checkpoint to resume from", 320 | type=Path, 321 | ) 322 | parser.add_argument( 323 | "--finetune", 324 | help="whether to finetune (note that a resume path must be given)", 325 | action="store_true", 326 | ) 327 | args = parser.parse_args() 328 | 329 | # display training setup info 330 | logger.info(f"PyTorch version: {torch.__version__}") 331 | logger.info(f"CUDA version: {torch.version.cuda}") 332 | logger.info(f"CUDNN version: {torch.backends.cudnn.version()}") 333 | logger.info(f"CUDNN enabled: {torch.backends.cudnn.enabled}") 334 | logger.info(f"CUDNN deterministic: {torch.backends.cudnn.deterministic}") 335 | logger.info(f"CUDNN benchmark: {torch.backends.cudnn.benchmark}") 336 | logger.info(f"# of GPUS: {torch.cuda.device_count()}") 337 | 338 | # clear handlers 339 | logger.handlers.clear() 340 | 341 | world_size = torch.cuda.device_count() 342 | mp.spawn( 343 | train_model, 344 | args=(world_size, args), 345 | nprocs=world_size, 346 | join=True, 347 | ) 348 | -------------------------------------------------------------------------------- /vocoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshall/hifigan/49c16c935cea45ba4073811de974d9ae277c3bc9/vocoder.png --------------------------------------------------------------------------------