├── LICENSE.md
├── README.md
├── hubconf.py
├── mel2wav
│   ├── __init__.py
│   ├── dataset.py
│   ├── interface.py
│   ├── modules.py
│   └── utils.py
├── melgan_slides.pdf
├── models
│   ├── linda_johnson.pt
│   └── multi_speaker.pt
├── requirements.txt
├── scripts
│   ├── generate_from_folder.py
│   └── train.py
└── set_env.sh

/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Descript Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Official repository for the paper MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis
2 |
3 | Previous works have found that generating coherent raw audio waveforms with GANs is challenging. In this [paper](https://arxiv.org/abs/1910.06711), we show that it is possible to train GANs reliably to generate high-quality coherent waveforms by introducing a set of architectural changes and simple training techniques. The subjective evaluation metric (Mean Opinion Score, or MOS) shows the effectiveness of the proposed approach for high-quality mel-spectrogram inversion. To establish the generality of the proposed techniques, we show qualitative results of our model in speech synthesis, music domain translation and unconditional music synthesis. We evaluate the various components of the model through ablation studies and suggest a set of guidelines to design general-purpose discriminators and generators for conditional sequence synthesis tasks. Our model is non-autoregressive and fully convolutional, has significantly fewer parameters than competing models, and generalizes to unseen speakers for mel-spectrogram inversion. Our PyTorch implementation runs more than 100x faster than real-time on a GTX 1080Ti GPU and more than 2x faster than real-time on CPU, without any hardware-specific optimization tricks. A blog post with samples and accompanying code is coming soon.
4 |
5 | Visit our [website](https://melgan-neurips.github.io) for samples. You can try the speech correction application [here](https://www.descript.com/overdub), which is built on an end-to-end speech synthesis pipeline that uses MelGAN.
6 |
7 | If you aren't attending the NeurIPS 2019 conference to see our poster, check out the [slides](melgan_slides.pdf).
8 |
9 |
10 | ## Code organization
11 |
12 | ├── README.md <- Top-level README.
13 | ├── set_env.sh <- Set PYTHONPATH and CUDA_VISIBLE_DEVICES.
14 | │
15 | ├── mel2wav
16 | │   ├── dataset.py <- data loader scripts
17 | │   ├── modules.py <- Model, layers and losses
18 | │   ├── utils.py <- Utilities to monitor, save, log, schedule etc.
19 | │
20 | ├── scripts
21 | │   ├── train.py <- training / validation / etc scripts
22 | │   ├── generate_from_folder.py
23 |
24 |
25 | ## Preparing dataset
26 | Create a raw data folder with all the samples stored in a `wavs/` subfolder.
27 | Run these commands:
28 | ```command
29 | ls wavs/*.wav | tail -n+10 > train_files.txt
30 | ls wavs/*.wav | head -n10 > test_files.txt
31 | ```
32 |
33 | ## Training Example
34 | source set_env.sh 0
35 | # Set PYTHONPATH and use first GPU
36 | python scripts/train.py --save_path logs/baseline --data_path <path_to_dataset_folder>
37 |
38 |
39 | ## PyTorch Hub Example
40 | import torch
41 | vocoder = torch.hub.load('descriptinc/melgan-neurips', 'load_melgan')
42 | vocoder.inverse(mel)  # mel (torch.tensor) of shape (batch_size, 80, timesteps) -> audio (batch_size, timesteps)
43 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ["torch", "librosa", "yaml"]
2 | from mel2wav import MelVocoder
3 |
4 |
5 | def load_melgan(model_name="multi_speaker"):
6 |     """
7 |     Exposes a MelVocoder interface.
8 |     Args:
9 |         model_name (str): Supports only 2 models, 'linda_johnson' or 'multi_speaker'
10 |     Returns:
11 |         object (MelVocoder): MelVocoder instance.
12 |             The default function (__call__) converts raw audio to a log-mel spectrogram;
13 |             the inverse function converts a log-mel spectrogram back to raw audio using MelGAN.
14 |     """
15 |
16 |     return MelVocoder(path=None, github=True, model_name=model_name)
17 |
--------------------------------------------------------------------------------
/mel2wav/__init__.py:
--------------------------------------------------------------------------------
1 | from mel2wav.interface import load_model, MelVocoder
2 |
--------------------------------------------------------------------------------
/mel2wav/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torch.nn.functional as F
4 |
5 | from librosa.core import load
6 | from librosa.util import normalize
7 |
8 | from pathlib import Path
9 | import numpy as np
10 | import random
11 |
12 |
13 | def files_to_list(filename):
14 |     """
15 |     Takes a text file of filenames and makes a list of filenames
16 |     """
17 |     with open(filename, encoding="utf-8") as f:
18 |         files = f.readlines()
19 |
20 |     files = [f.rstrip() for f in files]
21 |     return files
22 |
23 |
24 | class AudioDataset(torch.utils.data.Dataset):
25 |     """
26 |     Loads the audio files listed in a text file and returns random fixed-length
27 |     audio segments; the corresponding mel spectrograms are computed on the fly during training.
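    Usage sketch (illustrative, mirroring scripts/train.py; the file path is a placeholder):

        train_set = AudioDataset("data/train_files.txt", segment_length=8192, sampling_rate=22050)
        loader = torch.utils.data.DataLoader(train_set, batch_size=16, num_workers=4)
        audio = next(iter(loader))  # shape: (batch_size, 1, segment_length)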
28 | """ 29 | 30 | def __init__(self, training_files, segment_length, sampling_rate, augment=True): 31 | self.sampling_rate = sampling_rate 32 | self.segment_length = segment_length 33 | self.audio_files = files_to_list(training_files) 34 | self.audio_files = [Path(training_files).parent / x for x in self.audio_files] 35 | random.seed(1234) 36 | random.shuffle(self.audio_files) 37 | self.augment = augment 38 | 39 | def __getitem__(self, index): 40 | # Read audio 41 | filename = self.audio_files[index] 42 | audio, sampling_rate = self.load_wav_to_torch(filename) 43 | # Take segment 44 | if audio.size(0) >= self.segment_length: 45 | max_audio_start = audio.size(0) - self.segment_length 46 | audio_start = random.randint(0, max_audio_start) 47 | audio = audio[audio_start : audio_start + self.segment_length] 48 | else: 49 | audio = F.pad( 50 | audio, (0, self.segment_length - audio.size(0)), "constant" 51 | ).data 52 | 53 | # audio = audio / 32768.0 54 | return audio.unsqueeze(0) 55 | 56 | def __len__(self): 57 | return len(self.audio_files) 58 | 59 | def load_wav_to_torch(self, full_path): 60 | """ 61 | Loads wavdata into torch array 62 | """ 63 | data, sampling_rate = load(full_path, sr=self.sampling_rate) 64 | data = 0.95 * normalize(data) 65 | 66 | if self.augment: 67 | amplitude = np.random.uniform(low=0.3, high=1.0) 68 | data = data * amplitude 69 | 70 | return torch.from_numpy(data).float(), sampling_rate 71 | -------------------------------------------------------------------------------- /mel2wav/interface.py: -------------------------------------------------------------------------------- 1 | from mel2wav.modules import Generator, Audio2Mel 2 | 3 | from pathlib import Path 4 | import yaml 5 | import torch 6 | import os 7 | 8 | 9 | def get_default_device(): 10 | if torch.cuda.is_available(): 11 | return "cuda" 12 | else: 13 | return "cpu" 14 | 15 | 16 | def load_model(mel2wav_path, device=get_default_device()): 17 | """ 18 | Args: 19 | mel2wav_path (str or Path): path to the root folder of dumped text2mel 20 | device (str or torch.device): device to load the model 21 | """ 22 | root = Path(mel2wav_path) 23 | with open(root / "args.yml", "r") as f: 24 | args = yaml.load(f, Loader=yaml.FullLoader) 25 | netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers).to(device) 26 | netG.load_state_dict(torch.load(root / "best_netG.pt", map_location=device)) 27 | return netG 28 | 29 | 30 | class MelVocoder: 31 | def __init__( 32 | self, 33 | path, 34 | device=get_default_device(), 35 | github=False, 36 | model_name="multi_speaker", 37 | ): 38 | self.fft = Audio2Mel().to(device) 39 | if github: 40 | netG = Generator(80, 32, 3).to(device) 41 | root = Path(os.path.dirname(__file__)).parent 42 | netG.load_state_dict( 43 | torch.load(root / f"models/{model_name}.pt", map_location=device) 44 | ) 45 | self.mel2wav = netG 46 | else: 47 | self.mel2wav = load_model(path, device) 48 | self.device = device 49 | 50 | def __call__(self, audio): 51 | """ 52 | Performs audio to mel conversion (See Audio2Mel in mel2wav/modules.py) 53 | Args: 54 | audio (torch.tensor): PyTorch tensor containing audio (batch_size, timesteps) 55 | Returns: 56 | torch.tensor: log-mel-spectrogram computed on input audio (batch_size, 80, timesteps) 57 | """ 58 | return self.fft(audio.unsqueeze(1).to(self.device)) 59 | 60 | def inverse(self, mel): 61 | """ 62 | Performs mel2audio conversion 63 | Args: 64 | mel (torch.tensor): PyTorch tensor containing log-mel spectrograms (batch_size, 80, timesteps) 65 | Returns: 66 | 
torch.tensor: Inverted raw audio (batch_size, timesteps) 67 | 68 | """ 69 | with torch.no_grad(): 70 | return self.mel2wav(mel.to(self.device)).squeeze(1) 71 | -------------------------------------------------------------------------------- /mel2wav/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from librosa.filters import mel as librosa_mel_fn 5 | from torch.nn.utils import weight_norm 6 | import numpy as np 7 | 8 | 9 | def weights_init(m): 10 | classname = m.__class__.__name__ 11 | if classname.find("Conv") != -1: 12 | m.weight.data.normal_(0.0, 0.02) 13 | elif classname.find("BatchNorm2d") != -1: 14 | m.weight.data.normal_(1.0, 0.02) 15 | m.bias.data.fill_(0) 16 | 17 | 18 | def WNConv1d(*args, **kwargs): 19 | return weight_norm(nn.Conv1d(*args, **kwargs)) 20 | 21 | 22 | def WNConvTranspose1d(*args, **kwargs): 23 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 24 | 25 | 26 | class Audio2Mel(nn.Module): 27 | def __init__( 28 | self, 29 | n_fft=1024, 30 | hop_length=256, 31 | win_length=1024, 32 | sampling_rate=22050, 33 | n_mel_channels=80, 34 | mel_fmin=0.0, 35 | mel_fmax=None, 36 | ): 37 | super().__init__() 38 | ############################################## 39 | # FFT Parameters # 40 | ############################################## 41 | window = torch.hann_window(win_length).float() 42 | mel_basis = librosa_mel_fn( 43 | sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax 44 | ) 45 | mel_basis = torch.from_numpy(mel_basis).float() 46 | self.register_buffer("mel_basis", mel_basis) 47 | self.register_buffer("window", window) 48 | self.n_fft = n_fft 49 | self.hop_length = hop_length 50 | self.win_length = win_length 51 | self.sampling_rate = sampling_rate 52 | self.n_mel_channels = n_mel_channels 53 | 54 | def forward(self, audio): 55 | p = (self.n_fft - self.hop_length) // 2 56 | audio = F.pad(audio, (p, p), "reflect").squeeze(1) 57 | fft = torch.stft( 58 | audio, 59 | n_fft=self.n_fft, 60 | hop_length=self.hop_length, 61 | win_length=self.win_length, 62 | window=self.window, 63 | center=False, 64 | ) 65 | real_part, imag_part = fft.unbind(-1) 66 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 67 | mel_output = torch.matmul(self.mel_basis, magnitude) 68 | log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5)) 69 | return log_mel_spec 70 | 71 | 72 | class ResnetBlock(nn.Module): 73 | def __init__(self, dim, dilation=1): 74 | super().__init__() 75 | self.block = nn.Sequential( 76 | nn.LeakyReLU(0.2), 77 | nn.ReflectionPad1d(dilation), 78 | WNConv1d(dim, dim, kernel_size=3, dilation=dilation), 79 | nn.LeakyReLU(0.2), 80 | WNConv1d(dim, dim, kernel_size=1), 81 | ) 82 | self.shortcut = WNConv1d(dim, dim, kernel_size=1) 83 | 84 | def forward(self, x): 85 | return self.shortcut(x) + self.block(x) 86 | 87 | 88 | class Generator(nn.Module): 89 | def __init__(self, input_size, ngf, n_residual_layers): 90 | super().__init__() 91 | ratios = [8, 8, 2, 2] 92 | self.hop_length = np.prod(ratios) 93 | mult = int(2 ** len(ratios)) 94 | 95 | model = [ 96 | nn.ReflectionPad1d(3), 97 | WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0), 98 | ] 99 | 100 | # Upsample to raw audio scale 101 | for i, r in enumerate(ratios): 102 | model += [ 103 | nn.LeakyReLU(0.2), 104 | WNConvTranspose1d( 105 | mult * ngf, 106 | mult * ngf // 2, 107 | kernel_size=r * 2, 108 | stride=r, 109 | padding=r // 2 + r % 2, 110 | output_padding=r % 2, 111 | ), 112 | ] 
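            # The transposed conv above upsamples the time axis by r and halves the channels;
            # the residual stack below refines features at that resolution. The ratios
            # [8, 8, 2, 2] multiply to 256, matching Audio2Mel's default hop_length, so one
            # mel frame expands to 256 output samples.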
113 | 114 | for j in range(n_residual_layers): 115 | model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] 116 | 117 | mult //= 2 118 | 119 | model += [ 120 | nn.LeakyReLU(0.2), 121 | nn.ReflectionPad1d(3), 122 | WNConv1d(ngf, 1, kernel_size=7, padding=0), 123 | nn.Tanh(), 124 | ] 125 | 126 | self.model = nn.Sequential(*model) 127 | self.apply(weights_init) 128 | 129 | def forward(self, x): 130 | return self.model(x) 131 | 132 | 133 | class NLayerDiscriminator(nn.Module): 134 | def __init__(self, ndf, n_layers, downsampling_factor): 135 | super().__init__() 136 | model = nn.ModuleDict() 137 | 138 | model["layer_0"] = nn.Sequential( 139 | nn.ReflectionPad1d(7), 140 | WNConv1d(1, ndf, kernel_size=15), 141 | nn.LeakyReLU(0.2, True), 142 | ) 143 | 144 | nf = ndf 145 | stride = downsampling_factor 146 | for n in range(1, n_layers + 1): 147 | nf_prev = nf 148 | nf = min(nf * stride, 1024) 149 | 150 | model["layer_%d" % n] = nn.Sequential( 151 | WNConv1d( 152 | nf_prev, 153 | nf, 154 | kernel_size=stride * 10 + 1, 155 | stride=stride, 156 | padding=stride * 5, 157 | groups=nf_prev // 4, 158 | ), 159 | nn.LeakyReLU(0.2, True), 160 | ) 161 | 162 | nf = min(nf * 2, 1024) 163 | model["layer_%d" % (n_layers + 1)] = nn.Sequential( 164 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2), 165 | nn.LeakyReLU(0.2, True), 166 | ) 167 | 168 | model["layer_%d" % (n_layers + 2)] = WNConv1d( 169 | nf, 1, kernel_size=3, stride=1, padding=1 170 | ) 171 | 172 | self.model = model 173 | 174 | def forward(self, x): 175 | results = [] 176 | for key, layer in self.model.items(): 177 | x = layer(x) 178 | results.append(x) 179 | return results 180 | 181 | 182 | class Discriminator(nn.Module): 183 | def __init__(self, num_D, ndf, n_layers, downsampling_factor): 184 | super().__init__() 185 | self.model = nn.ModuleDict() 186 | for i in range(num_D): 187 | self.model[f"disc_{i}"] = NLayerDiscriminator( 188 | ndf, n_layers, downsampling_factor 189 | ) 190 | 191 | self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False) 192 | self.apply(weights_init) 193 | 194 | def forward(self, x): 195 | results = [] 196 | for key, disc in self.model.items(): 197 | results.append(disc(x)) 198 | x = self.downsample(x) 199 | return results 200 | -------------------------------------------------------------------------------- /mel2wav/utils.py: -------------------------------------------------------------------------------- 1 | import scipy.io.wavfile 2 | 3 | 4 | def save_sample(file_path, sampling_rate, audio): 5 | """Helper function to save sample 6 | 7 | Args: 8 | file_path (str or pathlib.Path): save file path 9 | sampling_rate (int): sampling rate of audio (usually 22050) 10 | audio (torch.FloatTensor): torch array containing audio in [-1, 1] 11 | """ 12 | audio = (audio.numpy() * 32768).astype("int16") 13 | scipy.io.wavfile.write(file_path, sampling_rate, audio) 14 | -------------------------------------------------------------------------------- /melgan_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/descriptinc/melgan-neurips/6488045bfba1975602288de07a58570c7b4d66ea/melgan_slides.pdf -------------------------------------------------------------------------------- /models/linda_johnson.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/descriptinc/melgan-neurips/6488045bfba1975602288de07a58570c7b4d66ea/models/linda_johnson.pt 
--------------------------------------------------------------------------------
/models/multi_speaker.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/descriptinc/melgan-neurips/6488045bfba1975602288de07a58570c7b4d66ea/models/multi_speaker.pt
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | librosa
3 | pyyaml
4 | scipy
5 | argparse
--------------------------------------------------------------------------------
/scripts/generate_from_folder.py:
--------------------------------------------------------------------------------
1 | from mel2wav import MelVocoder
2 |
3 | from pathlib import Path
4 | from tqdm import tqdm
5 | import argparse
6 | import librosa
7 | import torch
8 |
9 |
10 | def parse_args():
11 |     parser = argparse.ArgumentParser()
12 |     parser.add_argument("--load_path", type=Path, required=True)
13 |     parser.add_argument("--save_path", type=Path, required=True)
14 |     parser.add_argument("--folder", type=Path, required=True)
15 |     args = parser.parse_args()
16 |     return args
17 |
18 |
19 | def main():
20 |     args = parse_args()
21 |     vocoder = MelVocoder(args.load_path)
22 |
23 |     args.save_path.mkdir(exist_ok=True, parents=True)
24 |
25 |     for i, fname in tqdm(enumerate(args.folder.glob("*.wav"))):
26 |         wavname = fname.name
27 |         wav, sr = librosa.core.load(fname)
28 |
29 |         mel = vocoder(torch.from_numpy(wav)[None])
30 |         recons = vocoder.inverse(mel).squeeze().cpu().numpy()
31 |
32 |         librosa.output.write_wav(args.save_path / wavname, recons, sr=sr)
33 |
34 |
35 | if __name__ == "__main__":
36 |     main()
37 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | from mel2wav.dataset import AudioDataset
2 | from mel2wav.modules import Generator, Discriminator, Audio2Mel
3 | from mel2wav.utils import save_sample
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 | from torch.utils.tensorboard import SummaryWriter
9 |
10 | import yaml
11 | import numpy as np
12 | import time
13 | import argparse
14 | from pathlib import Path
15 |
16 |
17 | def parse_args():
18 |     parser = argparse.ArgumentParser()
19 |     parser.add_argument("--save_path", required=True)
20 |     parser.add_argument("--load_path", default=None)
21 |
22 |     parser.add_argument("--n_mel_channels", type=int, default=80)
23 |     parser.add_argument("--ngf", type=int, default=32)
24 |     parser.add_argument("--n_residual_layers", type=int, default=3)
25 |
26 |     parser.add_argument("--ndf", type=int, default=16)
27 |     parser.add_argument("--num_D", type=int, default=3)
28 |     parser.add_argument("--n_layers_D", type=int, default=4)
29 |     parser.add_argument("--downsamp_factor", type=int, default=4)
30 |     parser.add_argument("--lambda_feat", type=float, default=10)
31 |     parser.add_argument("--cond_disc", action="store_true")
32 |
33 |     parser.add_argument("--data_path", default=None, type=Path)
34 |     parser.add_argument("--batch_size", type=int, default=16)
35 |     parser.add_argument("--seq_len", type=int, default=8192)
36 |
37 |     parser.add_argument("--epochs", type=int, default=3000)
38 |     parser.add_argument("--log_interval", type=int, default=100)
39 |     parser.add_argument("--save_interval", type=int, default=1000)
40 |     parser.add_argument("--n_test_samples", 
type=int, default=8) 41 | args = parser.parse_args() 42 | return args 43 | 44 | 45 | def main(): 46 | args = parse_args() 47 | 48 | root = Path(args.save_path) 49 | load_root = Path(args.load_path) if args.load_path else None 50 | root.mkdir(parents=True, exist_ok=True) 51 | 52 | #################################### 53 | # Dump arguments and create logger # 54 | #################################### 55 | with open(root / "args.yml", "w") as f: 56 | yaml.dump(args, f) 57 | writer = SummaryWriter(str(root)) 58 | 59 | ####################### 60 | # Load PyTorch Models # 61 | ####################### 62 | netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers).cuda() 63 | netD = Discriminator( 64 | args.num_D, args.ndf, args.n_layers_D, args.downsamp_factor 65 | ).cuda() 66 | fft = Audio2Mel(n_mel_channels=args.n_mel_channels).cuda() 67 | 68 | print(netG) 69 | print(netD) 70 | 71 | ##################### 72 | # Create optimizers # 73 | ##################### 74 | optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9)) 75 | optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9)) 76 | 77 | if load_root and load_root.exists(): 78 | netG.load_state_dict(torch.load(load_root / "netG.pt")) 79 | optG.load_state_dict(torch.load(load_root / "optG.pt")) 80 | netD.load_state_dict(torch.load(load_root / "netD.pt")) 81 | optD.load_state_dict(torch.load(load_root / "optD.pt")) 82 | 83 | ####################### 84 | # Create data loaders # 85 | ####################### 86 | train_set = AudioDataset( 87 | Path(args.data_path) / "train_files.txt", args.seq_len, sampling_rate=22050 88 | ) 89 | test_set = AudioDataset( 90 | Path(args.data_path) / "test_files.txt", 91 | 22050 * 4, 92 | sampling_rate=22050, 93 | augment=False, 94 | ) 95 | 96 | train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4) 97 | test_loader = DataLoader(test_set, batch_size=1) 98 | 99 | ########################## 100 | # Dumping original audio # 101 | ########################## 102 | test_voc = [] 103 | test_audio = [] 104 | for i, x_t in enumerate(test_loader): 105 | x_t = x_t.cuda() 106 | s_t = fft(x_t).detach() 107 | 108 | test_voc.append(s_t.cuda()) 109 | test_audio.append(x_t) 110 | 111 | audio = x_t.squeeze().cpu() 112 | save_sample(root / ("original_%d.wav" % i), 22050, audio) 113 | writer.add_audio("original/sample_%d.wav" % i, audio, 0, sample_rate=22050) 114 | 115 | if i == args.n_test_samples - 1: 116 | break 117 | 118 | costs = [] 119 | start = time.time() 120 | 121 | # enable cudnn autotuner to speed up training 122 | torch.backends.cudnn.benchmark = True 123 | 124 | best_mel_reconst = 1000000 125 | steps = 0 126 | for epoch in range(1, args.epochs + 1): 127 | for iterno, x_t in enumerate(train_loader): 128 | x_t = x_t.cuda() 129 | s_t = fft(x_t).detach() 130 | x_pred_t = netG(s_t.cuda()) 131 | 132 | with torch.no_grad(): 133 | s_pred_t = fft(x_pred_t.detach()) 134 | s_error = F.l1_loss(s_t, s_pred_t).item() 135 | 136 | ####################### 137 | # Train Discriminator # 138 | ####################### 139 | D_fake_det = netD(x_pred_t.cuda().detach()) 140 | D_real = netD(x_t.cuda()) 141 | 142 | loss_D = 0 143 | for scale in D_fake_det: 144 | loss_D += F.relu(1 + scale[-1]).mean() 145 | 146 | for scale in D_real: 147 | loss_D += F.relu(1 - scale[-1]).mean() 148 | 149 | netD.zero_grad() 150 | loss_D.backward() 151 | optD.step() 152 | 153 | ################### 154 | # Train Generator # 155 | ################### 156 | D_fake = netD(x_pred_t.cuda()) 157 | 158 | loss_G = 
0 159 | for scale in D_fake: 160 | loss_G += -scale[-1].mean() 161 | 162 | loss_feat = 0 163 | feat_weights = 4.0 / (args.n_layers_D + 1) 164 | D_weights = 1.0 / args.num_D 165 | wt = D_weights * feat_weights 166 | for i in range(args.num_D): 167 | for j in range(len(D_fake[i]) - 1): 168 | loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach()) 169 | 170 | netG.zero_grad() 171 | (loss_G + args.lambda_feat * loss_feat).backward() 172 | optG.step() 173 | 174 | ###################### 175 | # Update tensorboard # 176 | ###################### 177 | costs.append([loss_D.item(), loss_G.item(), loss_feat.item(), s_error]) 178 | 179 | writer.add_scalar("loss/discriminator", costs[-1][0], steps) 180 | writer.add_scalar("loss/generator", costs[-1][1], steps) 181 | writer.add_scalar("loss/feature_matching", costs[-1][2], steps) 182 | writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps) 183 | steps += 1 184 | 185 | if steps % args.save_interval == 0: 186 | st = time.time() 187 | with torch.no_grad(): 188 | for i, (voc, _) in enumerate(zip(test_voc, test_audio)): 189 | pred_audio = netG(voc) 190 | pred_audio = pred_audio.squeeze().cpu() 191 | save_sample(root / ("generated_%d.wav" % i), 22050, pred_audio) 192 | writer.add_audio( 193 | "generated/sample_%d.wav" % i, 194 | pred_audio, 195 | epoch, 196 | sample_rate=22050, 197 | ) 198 | 199 | torch.save(netG.state_dict(), root / "netG.pt") 200 | torch.save(optG.state_dict(), root / "optG.pt") 201 | 202 | torch.save(netD.state_dict(), root / "netD.pt") 203 | torch.save(optD.state_dict(), root / "optD.pt") 204 | 205 | if np.asarray(costs).mean(0)[-1] < best_mel_reconst: 206 | best_mel_reconst = np.asarray(costs).mean(0)[-1] 207 | torch.save(netD.state_dict(), root / "best_netD.pt") 208 | torch.save(netG.state_dict(), root / "best_netG.pt") 209 | 210 | print("Took %5.4fs to generate samples" % (time.time() - st)) 211 | print("-" * 100) 212 | 213 | if steps % args.log_interval == 0: 214 | print( 215 | "Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format( 216 | epoch, 217 | iterno, 218 | len(train_loader), 219 | 1000 * (time.time() - start) / args.log_interval, 220 | np.asarray(costs).mean(0), 221 | ) 222 | ) 223 | costs = [] 224 | start = time.time() 225 | 226 | 227 | if __name__ == "__main__": 228 | main() 229 | -------------------------------------------------------------------------------- /set_env.sh: -------------------------------------------------------------------------------- 1 | YELLOW='\033[1;33m' 2 | GREEN='\033[0;32m' 3 | NC='\033[0m' # No Color 4 | 5 | if [ "$1" = "" ] 6 | then 7 | echo -e "\n${YELLOW}Warning: Not using GPU ${NC}\n" 8 | else 9 | echo -e "\n${GREEN}Using CUDA device $1 ${NC}\n" 10 | fi 11 | 12 | export CUDA_VISIBLE_DEVICES=$1 13 | export PYTHONPATH=$PWD:$PYTHONPATH 14 | --------------------------------------------------------------------------------
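For quick reference, below is a minimal local-inference sketch using the repository's `MelVocoder` interface (see `mel2wav/interface.py`). The checkpoint folder and wav filename are placeholders; it assumes a training run produced by `scripts/train.py`, which writes `args.yml` and `best_netG.pt` into `--save_path`.

```python
import librosa
import torch

from mel2wav import MelVocoder

# "logs/baseline" is a placeholder: any --save_path folder written by
# scripts/train.py (containing args.yml and best_netG.pt) will work here.
vocoder = MelVocoder("logs/baseline")

# Models in this repo are trained on 22050 Hz audio.
wav, sr = librosa.load("example.wav", sr=22050)
audio = torch.from_numpy(wav)[None]   # shape: (1, timesteps)

mel = vocoder(audio)                  # log-mel spectrogram, shape: (1, n_mel_channels, frames); 80 channels by default
recon = vocoder.inverse(mel)          # reconstructed waveform, shape: (1, timesteps)
```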