├── DataBaker
│   ├── val.txt
│   ├── 000001.wav
│   ├── 000002.wav
│   ├── 000003.wav
│   ├── 000004.wav
│   ├── 000005.wav
│   ├── 000006.wav
│   ├── 000007.wav
│   ├── 000008.wav
│   ├── 000009.wav
│   ├── 000010.wav
│   └── train.txt
├── export_onnx.py
├── test_files
│   └── 000010.wav
├── generated_files
│   ├── 000010.wav
│   ├── hifi-gan000010_generated.wav
│   └── adahifi-gan000010_generated.wav
├── requirements.txt
├── env.py
├── config_v1.json
├── ada_config_v1.json
├── LICENSE
├── utils.py
├── preprocess_aishell3.py
├── inference_e2e.py
├── AISHELL-3
│   └── val.txt
├── inference.py
├── README.md
├── meldataset.py
├── train_hifi_gan.py
├── train_ada_hifi_gan.py
└── models.py

--------------------------------------------------------------------------------
/DataBaker/val.txt:
--------------------------------------------------------------------------------
000010|

--------------------------------------------------------------------------------
/export_onnx.py:
--------------------------------------------------------------------------------
import torch

--------------------------------------------------------------------------------
/DataBaker/000001.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000001.wav

--------------------------------------------------------------------------------
/DataBaker/000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000002.wav

--------------------------------------------------------------------------------
/DataBaker/000003.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000003.wav

--------------------------------------------------------------------------------
/DataBaker/000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000004.wav

--------------------------------------------------------------------------------
/DataBaker/000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000005.wav

--------------------------------------------------------------------------------
/DataBaker/000006.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000006.wav

--------------------------------------------------------------------------------
/DataBaker/000007.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000007.wav

--------------------------------------------------------------------------------
/DataBaker/000008.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000008.wav

--------------------------------------------------------------------------------
/DataBaker/000009.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000009.wav

--------------------------------------------------------------------------------
/DataBaker/000010.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000010.wav -------------------------------------------------------------------------------- /test_files/000010.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/test_files/000010.wav -------------------------------------------------------------------------------- /generated_files/000010.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/generated_files/000010.wav -------------------------------------------------------------------------------- /DataBaker/train.txt: -------------------------------------------------------------------------------- 1 | 000001| 2 | 000002| 3 | 000003| 4 | 000004| 5 | 000005| 6 | 000006| 7 | 000007| 8 | 000008| 9 | 000009| 10 | -------------------------------------------------------------------------------- /generated_files/hifi-gan000010_generated.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/generated_files/hifi-gan000010_generated.wav -------------------------------------------------------------------------------- /generated_files/adahifi-gan000010_generated.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/generated_files/adahifi-gan000010_generated.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | numpy==1.17.4 3 | librosa==0.7.2 4 | scipy==1.4.1 5 | tensorboard==2.0 6 | soundfile==0.10.3.post1 7 | matplotlib==3.1.3 -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /config_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 32, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /ada_config_v1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 8, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 22050, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Xin Yuan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
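For orientation, here is a minimal sketch (not a file from the repository) of how the JSON configs above are consumed: the training and inference scripts later in this dump read the file, parse it, and wrap the result in `env.AttrDict` so every hyperparameter is reachable as an attribute.

```python
import json

from env import AttrDict

# Same loading pattern used by train_hifi_gan.py / inference.py:
with open("config_v1.json") as f:
    h = AttrDict(json.loads(f.read()))

# AttrDict points __dict__ at the dict itself, so attribute access and
# key access are interchangeable.
print(h.sampling_rate, h.n_fft, h.hop_size, h.num_mels)  # 22050 1024 256 80
assert h["segment_size"] == h.segment_size == 8192
```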
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | 59 | -------------------------------------------------------------------------------- /preprocess_aishell3.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # from shutil import copyfile 3 | # import librosa 4 | # from scipy.io import wavfile 5 | # import numpy as np 6 | # from tqdm import tqdm 7 | # 8 | # save_path = '/home/admin/yuanxin/2.TTSData/AISHELL-3/wavs' 9 | # 10 | # train_path = '/home/admin/yuanxin/2.TTSData/AISHELL-3/train/wav' 11 | # train_file = '/home/admin/yuanxin/2.TTSData/AISHELL-3/train/content.txt' 12 | # test_path = '/home/admin/yuanxin/2.TTSData/AISHELL-3/test/wav' 13 | # test_file = '/home/admin/yuanxin/2.TTSData/AISHELL-3/test/content.txt' 14 | # 15 | # 16 | # file_list = os.listdir(train_path) 17 | # for speaker_name in file_list: 18 | # wave_list = os.listdir(os.path.join(train_path, speaker_name)) 19 | # for wavname in wave_list: 20 | # copyfile(os.path.join(train_path, speaker_name, wavname), os.path.join(save_path, wavname)) 21 | # 22 | # file_list = os.listdir(test_path) 23 | # for speaker_name in file_list: 24 | # wave_list = os.listdir(os.path.join(test_path, speaker_name)) 25 | # for wavname in wave_list: 26 | # copyfile(os.path.join(test_path, speaker_name, wavname), os.path.join(save_path, wavname)) 27 | # 28 | 29 | # with open(train_file, 'r', encoding='utf-8') as f: 30 | # lines = f.readlines() 31 | # with open('AISHELL-3/train.txt', 'w', encoding='utf-8') as f1: 32 | # for l in lines: 33 | # l = l.split('.')[0] + '|' 34 | # f1.write(l + '\n') 35 | # 36 | # with open(test_file, 'r', encoding='utf-8') as f: 37 | # lines = f.readlines() 38 | # with open('AISHELL-3/val.txt', 'w', encoding='utf-8') as f1: 39 | # for l in lines: 40 | # l = l.split('.')[0] + '|' 41 | # f1.write(l + '\n') 42 | # 43 | 44 | # file_list = os.listdir(save_path) 45 | # for file in 
tqdm(file_list): 46 | # wav, _ = librosa.load(os.path.join(save_path, file), 22050) 47 | # wav = wav / max(abs(wav)) * 32767.0 48 | # wavfile.write( 49 | # os.path.join(save_path, file), 50 | # 22050, 51 | # wav.astype(np.int16), 52 | # ) 53 | -------------------------------------------------------------------------------- /inference_e2e.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import numpy as np 6 | import argparse 7 | import json 8 | import torch 9 | from scipy.io.wavfile import write 10 | from env import AttrDict 11 | from meldataset import MAX_WAV_VALUE 12 | from models import Generator 13 | 14 | h = None 15 | device = None 16 | 17 | 18 | def load_checkpoint(filepath, device): 19 | assert os.path.isfile(filepath) 20 | print("Loading '{}'".format(filepath)) 21 | checkpoint_dict = torch.load(filepath, map_location=device) 22 | print("Complete.") 23 | return checkpoint_dict 24 | 25 | 26 | def scan_checkpoint(cp_dir, prefix): 27 | pattern = os.path.join(cp_dir, prefix + '*') 28 | cp_list = glob.glob(pattern) 29 | if len(cp_list) == 0: 30 | return '' 31 | return sorted(cp_list)[-1] 32 | 33 | 34 | def inference(a): 35 | generator = Generator(h).to(device) 36 | 37 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 38 | generator.load_state_dict(state_dict_g['generator']) 39 | 40 | filelist = os.listdir(a.input_mels_dir) 41 | 42 | os.makedirs(a.output_dir, exist_ok=True) 43 | 44 | generator.eval() 45 | generator.remove_weight_norm() 46 | with torch.no_grad(): 47 | for i, filname in enumerate(filelist): 48 | x = np.load(os.path.join(a.input_mels_dir, filname)) 49 | x = torch.FloatTensor(x).to(device) 50 | y_g_hat = generator(x) 51 | audio = y_g_hat.squeeze() 52 | audio = audio * MAX_WAV_VALUE 53 | audio = audio.cpu().numpy().astype('int16') 54 | 55 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav') 56 | write(output_file, h.sampling_rate, audio) 57 | print(output_file) 58 | 59 | 60 | def main(): 61 | print('Initializing Inference Process..') 62 | 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--input_mels_dir', default='test_mel_files') 65 | parser.add_argument('--output_dir', default='generated_files_from_mel') 66 | parser.add_argument('--checkpoint_file', required=True) 67 | a = parser.parse_args() 68 | 69 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 70 | with open(config_file) as f: 71 | data = f.read() 72 | 73 | global h 74 | json_config = json.loads(data) 75 | h = AttrDict(json_config) 76 | 77 | torch.manual_seed(h.seed) 78 | global device 79 | if torch.cuda.is_available(): 80 | torch.cuda.manual_seed(h.seed) 81 | device = torch.device('cuda') 82 | else: 83 | device = torch.device('cpu') 84 | 85 | inference(a) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | 91 | -------------------------------------------------------------------------------- /AISHELL-3/val.txt: -------------------------------------------------------------------------------- 1 | SSB06930002| 2 | SSB06930003| 3 | SSB06930004| 4 | SSB06930005| 5 | SSB06930006| 6 | SSB06930007| 7 | SSB06930008| 8 | SSB06930010| 9 | SSB06930011| 10 | SSB06930012| 11 | SSB06930013| 12 | SSB06930014| 13 | SSB06930015| 14 | SSB06930016| 15 | SSB06930017| 16 | SSB06930018| 17 | SSB06930019| 18 | SSB06930021| 19 | SSB06930022| 20 | SSB06930023| 21 | SSB06930024| 22 
| SSB06930025| 23 | SSB06930026| 24 | SSB06930027| 25 | SSB06930028| 26 | SSB06930029| 27 | SSB06930030| 28 | SSB06930031| 29 | SSB06930032| 30 | SSB06930033| 31 | SSB06930034| 32 | SSB06930035| 33 | SSB06930036| 34 | SSB06930037| 35 | SSB06930039| 36 | SSB06930040| 37 | SSB06930041| 38 | SSB06930042| 39 | SSB06930044| 40 | SSB06930045| 41 | SSB06930047| 42 | SSB06930048| 43 | SSB06930049| 44 | SSB06930051| 45 | SSB06930053| 46 | SSB06930054| 47 | SSB06930056| 48 | SSB06930057| 49 | SSB06930058| 50 | SSB06930059| 51 | SSB06930060| 52 | SSB06930062| 53 | SSB06930063| 54 | SSB06930064| 55 | SSB06930065| 56 | SSB06930066| 57 | SSB06930067| 58 | SSB06930068| 59 | SSB06930070| 60 | SSB06930071| 61 | SSB06930072| 62 | SSB06930074| 63 | SSB06930075| 64 | SSB06930076| 65 | SSB06930077| 66 | SSB06930078| 67 | SSB06930079| 68 | SSB06930080| 69 | SSB06930020| 70 | SSB06930038| 71 | SSB06930061| 72 | SSB06930081| 73 | SSB06930099| 74 | SSB06930120| 75 | SSB06930141| 76 | SSB06930160| 77 | SSB06930181| 78 | SSB06930201| 79 | SSB06930219| 80 | SSB06930238| 81 | SSB06930257| 82 | SSB06930277| 83 | SSB06930295| 84 | SSB06930314| 85 | SSB06930335| 86 | SSB06930354| 87 | SSB06930374| 88 | SSB06930392| 89 | SSB06930412| 90 | SSB06930432| 91 | SSB06930451| 92 | SSB06930082| 93 | SSB06930083| 94 | SSB06930084| 95 | SSB06930085| 96 | SSB06930086| 97 | SSB06930087| 98 | SSB06930088| 99 | SSB06930089| 100 | SSB06930090| 101 | SSB06930091| 102 | SSB06930092| 103 | SSB06930093| 104 | SSB06930094| 105 | SSB06930095| 106 | SSB06930096| 107 | SSB06930097| 108 | SSB06930098| 109 | SSB06930100| 110 | SSB06930101| 111 | SSB06930102| 112 | SSB06930104| 113 | SSB06930105| 114 | SSB06930106| 115 | SSB06930108| 116 | SSB06930109| 117 | SSB06930110| 118 | SSB06930111| 119 | SSB06930112| 120 | SSB06930113| 121 | SSB06930114| 122 | SSB06930115| 123 | SSB06930117| 124 | SSB06930118| 125 | SSB06930119| 126 | SSB06930121| 127 | SSB06930122| 128 | SSB06930123| 129 | SSB06930124| 130 | SSB06930126| 131 | SSB06930127| 132 | SSB06930128| 133 | SSB06930129| 134 | SSB06930130| 135 | SSB06930131| 136 | SSB06930132| 137 | SSB06930134| 138 | SSB06930135| 139 | SSB06930136| 140 | SSB06930137| 141 | SSB06930138| 142 | SSB06930140| 143 | SSB06930142| 144 | SSB06930143| 145 | SSB06930144| 146 | SSB06930145| 147 | SSB06930146| 148 | SSB06930147| 149 | SSB06930148| 150 | SSB06930149| 151 | SSB06930150| 152 | SSB06930151| 153 | SSB06930152| 154 | SSB06930153| 155 | SSB06930154| 156 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import argparse 6 | import json 7 | 8 | import numpy as np 9 | import torch 10 | from scipy.io.wavfile import write 11 | from env import AttrDict 12 | from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav 13 | from models import Generator 14 | 15 | h = None 16 | device = None 17 | 18 | 19 | def load_checkpoint(filepath, device): 20 | assert os.path.isfile(filepath) 21 | print("Loading '{}'".format(filepath)) 22 | checkpoint_dict = torch.load(filepath, map_location=device) 23 | print("Complete.") 24 | return checkpoint_dict 25 | 26 | 27 | def get_mel(x): 28 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) 29 | 30 | 31 | def scan_checkpoint(cp_dir, prefix): 32 | pattern = os.path.join(cp_dir, prefix + '*') 
33 | cp_list = glob.glob(pattern) 34 | if len(cp_list) == 0: 35 | return '' 36 | return sorted(cp_list)[-1] 37 | 38 | 39 | def inference(a): 40 | generator = Generator(h).to(device) 41 | 42 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 43 | generator.load_state_dict(state_dict_g['generator']) 44 | 45 | filelist = os.listdir(a.input_wavs_dir) 46 | 47 | os.makedirs(a.output_dir, exist_ok=True) 48 | 49 | generator.eval() 50 | generator.remove_weight_norm() 51 | with torch.no_grad(): 52 | for i, filname in enumerate(filelist): 53 | wav, sr = load_wav(os.path.join(a.input_wavs_dir, filname)) 54 | wav = wav / MAX_WAV_VALUE 55 | wav = torch.FloatTensor(wav).to(device) 56 | x = get_mel(wav.unsqueeze(0)) 57 | y_g_hat = generator(x) 58 | audio = y_g_hat.squeeze() 59 | audio = audio * MAX_WAV_VALUE 60 | audio = audio.cpu().numpy().astype('int16') 61 | 62 | output_file = os.path.join(a.output_dir, a.model_name + os.path.splitext(filname)[0] + '_generated.wav') 63 | write(output_file, h.sampling_rate, audio) 64 | print(output_file) 65 | 66 | 67 | def main(): 68 | print('Initializing Inference Process..') 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('--input_wavs_dir', default='test_files') 72 | parser.add_argument('--output_dir', default='generated_files') 73 | parser.add_argument('--model_name', required=True) 74 | parser.add_argument('--checkpoint_file', required=True) 75 | a = parser.parse_args() 76 | 77 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json') 78 | with open(config_file) as f: 79 | data = f.read() 80 | 81 | global h 82 | json_config = json.loads(data) 83 | h = AttrDict(json_config) 84 | 85 | torch.manual_seed(h.seed) 86 | global device 87 | if torch.cuda.is_available(): 88 | torch.cuda.manual_seed(h.seed) 89 | device = torch.device('cuda') 90 | else: 91 | device = torch.device('cpu') 92 | 93 | inference(a) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaVocoder: Adaptive Vocoder for Custom Voice 2 | 3 | In our [paper](https://www.isca-speech.org/archive/interspeech_2022/yuan22_interspeech.html), 4 | we proposed AdaVocoder: Adaptive Vocoder for Custom Voice.
We provide our implementation and pretrained models for `AdaHiFi-GAN` as open source in this repository.

**Abstract:**

Custom voice builds a personal speech synthesis system by adapting a source speech synthesis model to a target speaker using only a few target recordings. The usual solution is to combine an adaptive acoustic model with a robust vocoder. However, training a robust vocoder requires a multi-speaker dataset covering various age groups and timbres so that the trained vocoder can be used for unseen speakers. Collecting such a multi-speaker dataset is difficult, and its distribution always has a mismatch with the distribution of the target speaker's data.

This paper proposes an adaptive vocoder for custom voice that solves the above problems from a novel perspective. The adaptive vocoder mainly uses a cross-domain consistency loss to address the overfitting that GAN-based neural vocoders suffer during few-shot transfer learning. We construct two adaptive vocoders, AdaMelGAN and AdaHiFi-GAN. First, we pre-train the source vocoder models on the AISHELL-3 and CSMSC datasets, respectively. Then, we fine-tune them on the internal dataset VXI-children with a small amount of adaptation data. The empirical results show that a high-quality custom voice system can be built by combining an adaptive acoustic model with an adaptive vocoder.

## Pre-requisites
1. Python >= 3.6
2. Clone this repository.
3. Install the Python requirements. Please refer to [requirements.txt](requirements.txt).
4. Download and extract the [AISHELL3 dataset](http://www.aishelltech.com/aishell_3), then rename the dataset folder or create a link to it: `ln -s /path/to/AISHELL-3/wavs DUMMY1`. Move all wav files into `AISHELL-3/wavs` and resample every file to 22050 Hz (a sketch of this step follows the training commands below).

## Training HiFi-GAN
```
python train_hifi_gan.py --config config_v1.json
```
- Tensorboard
```
tensorboard --logdir cp_hifigan/logs/ --bind_all
```

Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
You can change the path with the `--checkpoint_path` option.
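Pre-requisite 4 above asks for every AISHELL-3 wav to be resampled to 22050 Hz; the repository's `preprocess_aishell3.py` shows this step only in commented-out form. A minimal runnable sketch of the same resampling pass, assuming the wav files have already been collected into a single directory (the `AISHELL-3/wavs` path below is an assumption):

```python
import os

import librosa
import numpy as np
from scipy.io import wavfile

WAV_DIR = "AISHELL-3/wavs"  # assumed location of the collected wav files
TARGET_SR = 22050           # sampling rate expected by config_v1.json

for name in os.listdir(WAV_DIR):
    if not name.endswith(".wav"):
        continue
    path = os.path.join(WAV_DIR, name)
    # librosa resamples to the requested rate and returns float32 audio in [-1, 1]
    wav, _ = librosa.load(path, sr=TARGET_SR)
    # peak-normalize and write back as 16-bit PCM, mirroring preprocess_aishell3.py
    wav = wav / np.abs(wav).max() * 32767.0
    wavfile.write(path, TARGET_SR, wav.astype(np.int16))
```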
## Pretrained Model
You can also use the pretrained models we provide.
[Download AISHELL3 pretrained models](https://drive.google.com/file/d/1lqp-8mQIultA2nQ9lY3SNyUpqDpZTHLk/view?usp=sharing)

## Training AdaHiFi-GAN
First you need to save the pre-trained `AISHELL-3` model to `cp_ada_hifigan`.
Due to confidentiality requirements, VXI-children is not used here. I tested it using the child voice samples shared by [Data-Baker](https://www.data-baker.com/).

```
python train_ada_hifi_gan.py --config config_v1.json
```
- Tensorboard
```
tensorboard --logdir cp_ada_hifigan/logs/ --bind_all
```
Checkpoints and a copy of the configuration file are saved in the `cp_ada_hifigan` directory by default.
You can change the path with the `--checkpoint_path` option. The repository also includes `ada_config_v1.json`, which differs from `config_v1.json` only in its smaller batch size of 8.
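The core of this adaptation step is the cross-domain consistency loss mentioned in the abstract: `train_ada_hifi_gan.py` runs the same mel input through the frozen source generator (`Generator_S`) and the trainable target generator (`Generator_T`) and penalizes divergence between their intermediate feature maps via `ada_loss`, whose definition is not included in this dump. Purely as an illustration of one common way such a consistency loss is formulated, and not the repository's actual implementation, here is a sketch that asks the target generator to preserve the source generator's pairwise similarities between batch items at every layer:

```python
import torch
import torch.nn.functional as F


def cross_domain_consistency_loss(fmap_s, fmap_t):
    """Hypothetical stand-in for ada_loss: match, layer by layer, the pairwise
    similarity structure of frozen source features (fmap_s) and trainable
    target features (fmap_t). Each entry is a (B, C, T) feature map; the
    batch size B must be at least 2."""
    total = 0.0
    per_layer = []
    for fs, ft in zip(fmap_s, fmap_t):
        b = fs.size(0)
        fs = fs.reshape(b, -1)
        ft = ft.reshape(b, -1)
        # (B, B) cosine similarities between every pair of items in the batch
        sim_s = F.cosine_similarity(fs.unsqueeze(1), fs.unsqueeze(0), dim=-1)
        sim_t = F.cosine_similarity(ft.unsqueeze(1), ft.unsqueeze(0), dim=-1)
        off_diag = ~torch.eye(b, dtype=torch.bool, device=fs.device)
        # turn each row of off-diagonal similarities into a distribution and
        # pull the target distribution towards the (detached) source one
        p_s = F.softmax(sim_s[off_diag].view(b, b - 1), dim=-1).detach()
        log_p_t = F.log_softmax(sim_t[off_diag].view(b, b - 1), dim=-1)
        layer_loss = F.kl_div(log_p_t, p_s, reduction="batchmean")
        per_layer.append(layer_loss)
        total = total + layer_loss
    return total, per_layer
```

In the training script this term is scaled (the script multiplies `a_loss` by 1000) and added to the usual HiFi-GAN generator objective.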
## Inference from wav file
1. Make a `test_files` directory and copy wav files into it.
2. Run the following command.
```
python inference.py --checkpoint_file [generator checkpoint file path] --model_name [hifi-gan or adahifi-gan]
```
Generated wav files are saved in `generated_files` by default.
70 | You can change the path by adding `--output_dir` option. 71 | 72 | ## [Some Sample](https://yuan1615.github.io/2022/09/21/AdaVocoder/) 73 | 74 | 75 | ## Acknowledgements 76 | We referred to [HiFi-GAN](https://github.com/jik876/hifi-gan) to implement this. 77 | 78 | -------------------------------------------------------------------------------- /meldataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | import torch.utils.data 6 | import numpy as np 7 | from librosa.util import normalize 8 | from scipy.io.wavfile import read 9 | from librosa.filters import mel as librosa_mel_fn 10 | 11 | MAX_WAV_VALUE = 32768.0 12 | 13 | 14 | def load_wav(full_path): 15 | sampling_rate, data = read(full_path) 16 | return data, sampling_rate 17 | 18 | 19 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 20 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 21 | 22 | 23 | def dynamic_range_decompression(x, C=1): 24 | return np.exp(x) / C 25 | 26 | 27 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 28 | return torch.log(torch.clamp(x, min=clip_val) * C) 29 | 30 | 31 | def dynamic_range_decompression_torch(x, C=1): 32 | return torch.exp(x) / C 33 | 34 | 35 | def spectral_normalize_torch(magnitudes): 36 | output = dynamic_range_compression_torch(magnitudes) 37 | return output 38 | 39 | 40 | def spectral_de_normalize_torch(magnitudes): 41 | output = dynamic_range_decompression_torch(magnitudes) 42 | return output 43 | 44 | 45 | mel_basis = {} 46 | hann_window = {} 47 | 48 | 49 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 50 | if torch.min(y) < -1.: 51 | print('min value is ', torch.min(y)) 52 | if torch.max(y) > 1.: 53 | print('max value is ', torch.max(y)) 54 | 55 | global mel_basis, hann_window 56 | if fmax not in mel_basis: 57 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 58 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 59 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 60 | 61 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 62 | y = y.squeeze(1) 63 | 64 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 65 | center=center, pad_mode='reflect', normalized=False, onesided=True) 66 | 67 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 68 | 69 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 70 | spec = spectral_normalize_torch(spec) 71 | 72 | return spec 73 | 74 | 75 | def get_dataset_filelist(a): 76 | with open(a.input_training_file, 'r', encoding='utf-8') as fi: 77 | training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') 78 | for x in fi.read().split('\n') if len(x) > 0] 79 | 80 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi: 81 | validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') 82 | for x in fi.read().split('\n') if len(x) > 0] 83 | return training_files, validation_files 84 | 85 | 86 | class MelDataset(torch.utils.data.Dataset): 87 | def __init__(self, training_files, segment_size, n_fft, num_mels, 88 | hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, 89 | device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None): 90 | self.audio_files = training_files 91 | random.seed(1234) 
92 | if shuffle: 93 | random.shuffle(self.audio_files) 94 | self.segment_size = segment_size 95 | self.sampling_rate = sampling_rate 96 | self.split = split 97 | self.n_fft = n_fft 98 | self.num_mels = num_mels 99 | self.hop_size = hop_size 100 | self.win_size = win_size 101 | self.fmin = fmin 102 | self.fmax = fmax 103 | self.fmax_loss = fmax_loss 104 | self.cached_wav = None 105 | self.n_cache_reuse = n_cache_reuse 106 | self._cache_ref_count = 0 107 | self.device = device 108 | self.fine_tuning = fine_tuning 109 | self.base_mels_path = base_mels_path 110 | 111 | def __getitem__(self, index): 112 | filename = self.audio_files[index] 113 | if self._cache_ref_count == 0: 114 | audio, sampling_rate = load_wav(filename) 115 | audio = audio / MAX_WAV_VALUE 116 | if not self.fine_tuning: 117 | audio = normalize(audio) * 0.95 118 | self.cached_wav = audio 119 | if sampling_rate != self.sampling_rate: 120 | raise ValueError("{} SR doesn't match target {} SR".format( 121 | sampling_rate, self.sampling_rate)) 122 | self._cache_ref_count = self.n_cache_reuse 123 | else: 124 | audio = self.cached_wav 125 | self._cache_ref_count -= 1 126 | 127 | audio = torch.FloatTensor(audio) 128 | audio = audio.unsqueeze(0) 129 | 130 | if not self.fine_tuning: 131 | if self.split: 132 | if audio.size(1) >= self.segment_size: 133 | max_audio_start = audio.size(1) - self.segment_size 134 | audio_start = random.randint(0, max_audio_start) 135 | audio = audio[:, audio_start:audio_start+self.segment_size] 136 | else: 137 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 138 | 139 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels, 140 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, 141 | center=False) 142 | else: 143 | mel = np.load( 144 | os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy')) 145 | mel = torch.from_numpy(mel) 146 | 147 | if len(mel.shape) < 3: 148 | mel = mel.unsqueeze(0) 149 | 150 | if self.split: 151 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 152 | 153 | if audio.size(1) >= self.segment_size: 154 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 155 | mel = mel[:, :, mel_start:mel_start + frames_per_seg] 156 | audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size] 157 | else: 158 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant') 159 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 160 | 161 | mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels, 162 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss, 163 | center=False) 164 | 165 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 166 | 167 | def __len__(self): 168 | return len(self.audio_files) 169 | -------------------------------------------------------------------------------- /train_hifi_gan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter(action='ignore', category=FutureWarning) 4 | import itertools 5 | import os 6 | import time 7 | import argparse 8 | import json 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from 
torch.nn.parallel import DistributedDataParallel 16 | from env import AttrDict, build_env 17 | from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist 18 | from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss, \ 19 | discriminator_loss 20 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint 21 | 22 | torch.backends.cudnn.benchmark = True 23 | 24 | 25 | def train(rank, a, h): 26 | if h.num_gpus > 1: 27 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], 28 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) 29 | 30 | torch.cuda.manual_seed(h.seed) 31 | device = torch.device('cuda:{:d}'.format(rank)) 32 | 33 | generator = Generator(h).to(device) 34 | mpd = MultiPeriodDiscriminator().to(device) 35 | msd = MultiScaleDiscriminator().to(device) 36 | 37 | if rank == 0: 38 | print(generator) 39 | os.makedirs(a.checkpoint_path, exist_ok=True) 40 | print("checkpoints directory : ", a.checkpoint_path) 41 | 42 | if os.path.isdir(a.checkpoint_path): 43 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_') 44 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_') 45 | 46 | steps = 0 47 | if cp_g is None or cp_do is None: 48 | state_dict_do = None 49 | last_epoch = -1 50 | else: 51 | state_dict_g = load_checkpoint(cp_g, device) 52 | state_dict_do = load_checkpoint(cp_do, device) 53 | generator.load_state_dict(state_dict_g['generator']) 54 | mpd.load_state_dict(state_dict_do['mpd']) 55 | msd.load_state_dict(state_dict_do['msd']) 56 | steps = state_dict_do['steps'] + 1 57 | last_epoch = state_dict_do['epoch'] 58 | 59 | if h.num_gpus > 1: 60 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) 61 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) 62 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) 63 | 64 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 65 | optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), 66 | h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 67 | 68 | if state_dict_do is not None: 69 | optim_g.load_state_dict(state_dict_do['optim_g']) 70 | optim_d.load_state_dict(state_dict_do['optim_d']) 71 | 72 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) 73 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) 74 | 75 | training_filelist, validation_filelist = get_dataset_filelist(a) 76 | 77 | trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, 78 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, 79 | shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, 80 | fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir) 81 | 82 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None 83 | 84 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, 85 | sampler=train_sampler, 86 | batch_size=h.batch_size, 87 | pin_memory=True, 88 | drop_last=True) 89 | 90 | if rank == 0: 91 | validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, 92 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, 93 | fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, 94 | base_mels_path=a.input_mels_dir) 95 | 
validation_loader = DataLoader(validset, num_workers=1, shuffle=False, 96 | sampler=None, 97 | batch_size=1, 98 | pin_memory=True, 99 | drop_last=True) 100 | 101 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) 102 | 103 | generator.train() 104 | mpd.train() 105 | msd.train() 106 | for epoch in range(max(0, last_epoch), a.training_epochs): 107 | if rank == 0: 108 | start = time.time() 109 | print("Epoch: {}".format(epoch + 1)) 110 | 111 | if h.num_gpus > 1: 112 | train_sampler.set_epoch(epoch) 113 | 114 | for i, batch in enumerate(train_loader): 115 | if rank == 0: 116 | start_b = time.time() 117 | x, y, _, y_mel = batch 118 | x = torch.autograd.Variable(x.to(device, non_blocking=True)) 119 | y = torch.autograd.Variable(y.to(device, non_blocking=True)) 120 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 121 | y = y.unsqueeze(1) 122 | # print(y.shape) 123 | # y.shape: [16, 1, 8192] 124 | y_g_hat = generator(x) 125 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, 126 | h.win_size, 127 | h.fmin, h.fmax_for_loss) 128 | 129 | optim_d.zero_grad() 130 | 131 | # MPD 132 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) 133 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) 134 | 135 | # MSD 136 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) 137 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) 138 | 139 | loss_disc_all = loss_disc_s + loss_disc_f 140 | 141 | loss_disc_all.backward() 142 | optim_d.step() 143 | 144 | # Generator 145 | optim_g.zero_grad() 146 | 147 | # L1 Mel-Spectrogram Loss 148 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 149 | 150 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) 151 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) 152 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 153 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 154 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) 155 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) 156 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel 157 | 158 | loss_gen_all.backward() 159 | optim_g.step() 160 | 161 | if rank == 0: 162 | # STDOUT logging 163 | if steps % a.stdout_interval == 0: 164 | with torch.no_grad(): 165 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() 166 | 167 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. 
168 | format(steps, loss_gen_all, mel_error, time.time() - start_b)) 169 | 170 | # checkpointing 171 | if steps % a.checkpoint_interval == 0 and steps != 0: 172 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps) 173 | save_checkpoint(checkpoint_path, 174 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) 175 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps) 176 | save_checkpoint(checkpoint_path, 177 | {'mpd': (mpd.module if h.num_gpus > 1 178 | else mpd).state_dict(), 179 | 'msd': (msd.module if h.num_gpus > 1 180 | else msd).state_dict(), 181 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 182 | 'epoch': epoch}) 183 | 184 | # Tensorboard summary logging 185 | if steps % a.summary_interval == 0: 186 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) 187 | sw.add_scalar("training/mel_spec_error", mel_error, steps) 188 | 189 | # Validation 190 | if steps % a.validation_interval == 0: # and steps != 0: 191 | generator.eval() 192 | torch.cuda.empty_cache() 193 | val_err_tot = 0 194 | with torch.no_grad(): 195 | for j, batch in enumerate(validation_loader): 196 | x, y, _, y_mel = batch 197 | y_g_hat = generator(x.to(device)) 198 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 199 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, 200 | h.hop_size, h.win_size, 201 | h.fmin, h.fmax_for_loss) 202 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() 203 | 204 | if j <= 4: 205 | if steps == 0: 206 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) 207 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) 208 | 209 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate) 210 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, 211 | h.sampling_rate, h.hop_size, h.win_size, 212 | h.fmin, h.fmax) 213 | sw.add_figure('generated/y_hat_spec_{}'.format(j), 214 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps) 215 | 216 | val_err = val_err_tot / (j + 1) 217 | sw.add_scalar("validation/mel_spec_error", val_err, steps) 218 | 219 | generator.train() 220 | 221 | steps += 1 222 | 223 | scheduler_g.step() 224 | scheduler_d.step() 225 | 226 | if rank == 0: 227 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) 228 | 229 | 230 | def main(): 231 | print('Initializing Training Process..') 232 | 233 | parser = argparse.ArgumentParser() 234 | 235 | parser.add_argument('--group_name', default=None) 236 | parser.add_argument('--input_wavs_dir', default='DUMMY1') 237 | parser.add_argument('--input_mels_dir', default='ft_dataset') 238 | parser.add_argument('--input_training_file', 239 | default='AISHELL-3/train.txt') 240 | parser.add_argument('--input_validation_file', 241 | default='AISHELL-3/val.txt') 242 | parser.add_argument('--checkpoint_path', default='cp_hifigan') 243 | parser.add_argument('--config', default='') 244 | parser.add_argument('--training_epochs', default=3100, type=int) 245 | parser.add_argument('--stdout_interval', default=5, type=int) 246 | parser.add_argument('--checkpoint_interval', default=10000, type=int) 247 | parser.add_argument('--summary_interval', default=100, type=int) 248 | parser.add_argument('--validation_interval', default=1000, type=int) 249 | parser.add_argument('--fine_tuning', default=False, type=bool) 250 | 251 | a = parser.parse_args() 252 | 253 | with open(a.config) as f: 254 
| data = f.read() 255 | 256 | json_config = json.loads(data) 257 | h = AttrDict(json_config) 258 | build_env(a.config, 'config.json', a.checkpoint_path) 259 | 260 | torch.manual_seed(h.seed) 261 | if torch.cuda.is_available(): 262 | torch.cuda.manual_seed(h.seed) 263 | h.num_gpus = torch.cuda.device_count() 264 | h.batch_size = int(h.batch_size / h.num_gpus) 265 | print('Batch size per GPU :', h.batch_size) 266 | else: 267 | pass 268 | 269 | if h.num_gpus > 1: 270 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,)) 271 | else: 272 | train(0, a, h) 273 | 274 | 275 | if __name__ == '__main__': 276 | main() 277 | -------------------------------------------------------------------------------- /train_ada_hifi_gan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter(action='ignore', category=FutureWarning) 4 | import itertools 5 | import os 6 | import time 7 | import argparse 8 | import json 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from torch.nn.parallel import DistributedDataParallel 16 | from env import AttrDict, build_env 17 | from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist 18 | from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss, \ 19 | discriminator_loss, Generator_S, Generator_T, ada_loss 20 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint 21 | 22 | torch.backends.cudnn.benchmark = True 23 | 24 | 25 | def train(rank, a, h): 26 | if h.num_gpus > 1: 27 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], 28 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) 29 | 30 | torch.cuda.manual_seed(h.seed) 31 | device = torch.device('cuda:{:d}'.format(rank)) 32 | 33 | generator_s = Generator_S(h).to(device) 34 | generator_t = Generator_T(h).to(device) 35 | 36 | mpd = MultiPeriodDiscriminator().to(device) 37 | msd = MultiScaleDiscriminator().to(device) 38 | 39 | if rank == 0: 40 | print(generator_t) 41 | os.makedirs(a.checkpoint_path, exist_ok=True) 42 | print("checkpoints directory : ", a.checkpoint_path) 43 | 44 | if os.path.isdir(a.checkpoint_path): 45 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_') 46 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_') 47 | 48 | steps = 0 49 | if cp_g is None or cp_do is None: 50 | state_dict_do = None 51 | last_epoch = -1 52 | else: 53 | state_dict_g = load_checkpoint(cp_g, device) 54 | state_dict_do = load_checkpoint(cp_do, device) 55 | generator_s.load_state_dict(state_dict_g['generator']) 56 | generator_t.load_state_dict(state_dict_g['generator']) 57 | mpd.load_state_dict(state_dict_do['mpd']) 58 | msd.load_state_dict(state_dict_do['msd']) 59 | steps = state_dict_do['steps'] + 1 60 | last_epoch = state_dict_do['epoch'] 61 | 62 | if h.num_gpus > 1: 63 | generator_s = DistributedDataParallel(generator_s, device_ids=[rank]).to(device) 64 | generator_t = DistributedDataParallel(generator_t, device_ids=[rank]).to(device) 65 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) 66 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) 67 | 68 | optim_g_t = torch.optim.AdamW(generator_t.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 69 | 
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), 70 | h.learning_rate, betas=[h.adam_b1, h.adam_b2]) 71 | 72 | if state_dict_do is not None: 73 | optim_g_t.load_state_dict(state_dict_do['optim_g']) 74 | optim_d.load_state_dict(state_dict_do['optim_d']) 75 | 76 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g_t, gamma=h.lr_decay, last_epoch=last_epoch) 77 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) 78 | 79 | training_filelist, validation_filelist = get_dataset_filelist(a) 80 | 81 | trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, 82 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, 83 | shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, 84 | fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir) 85 | 86 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None 87 | 88 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, 89 | sampler=train_sampler, 90 | batch_size=h.batch_size, 91 | pin_memory=True, 92 | drop_last=True) 93 | 94 | if rank == 0: 95 | validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, 96 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, 97 | fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, 98 | base_mels_path=a.input_mels_dir) 99 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False, 100 | sampler=None, 101 | batch_size=1, 102 | pin_memory=True, 103 | drop_last=True) 104 | 105 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) 106 | 107 | generator_t.train() 108 | mpd.train() 109 | msd.train() 110 | for epoch in range(max(0, last_epoch), a.training_epochs): 111 | if rank == 0: 112 | start = time.time() 113 | print("Epoch: {}".format(epoch + 1)) 114 | 115 | if h.num_gpus > 1: 116 | train_sampler.set_epoch(epoch) 117 | 118 | for i, batch in enumerate(train_loader): 119 | if rank == 0: 120 | start_b = time.time() 121 | x, y, _, y_mel = batch 122 | x = torch.autograd.Variable(x.to(device, non_blocking=True)) 123 | y = torch.autograd.Variable(y.to(device, non_blocking=True)) 124 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 125 | y = y.unsqueeze(1) 126 | _, fmap_s = generator_s(x) 127 | y_g_hat, fmap_t = generator_t(x) 128 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, 129 | h.win_size, 130 | h.fmin, h.fmax_for_loss) 131 | 132 | optim_d.zero_grad() 133 | 134 | # MPD 135 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) 136 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) 137 | 138 | # MSD 139 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) 140 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) 141 | 142 | loss_disc_all = loss_disc_s + loss_disc_f 143 | 144 | loss_disc_all.backward() 145 | optim_d.step() 146 | 147 | # Generator 148 | optim_g_t.zero_grad() 149 | 150 | # ada loss 151 | a_loss, _ = ada_loss(fmap_s, fmap_t) 152 | a_loss = a_loss * 1000 153 | 154 | # L1 Mel-Spectrogram Loss 155 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 156 | 157 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) 158 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) 159 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 160 | loss_fm_s = 
feature_loss(fmap_s_r, fmap_s_g) 161 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) 162 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) 163 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + a_loss 164 | 165 | loss_gen_all.backward() 166 | optim_g_t.step() 167 | 168 | if rank == 0: 169 | # STDOUT logging 170 | if steps % a.stdout_interval == 0: 171 | with torch.no_grad(): 172 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() 173 | 174 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. 175 | format(steps, loss_gen_all, mel_error, time.time() - start_b)) 176 | 177 | # checkpointing 178 | if steps % a.checkpoint_interval == 0 and steps != 0: 179 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps) 180 | save_checkpoint(checkpoint_path, 181 | {'generator': (generator_t.module if h.num_gpus > 1 else generator_t).state_dict()}) 182 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps) 183 | save_checkpoint(checkpoint_path, 184 | {'mpd': (mpd.module if h.num_gpus > 1 185 | else mpd).state_dict(), 186 | 'msd': (msd.module if h.num_gpus > 1 187 | else msd).state_dict(), 188 | 'optim_g': optim_g_t.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 189 | 'epoch': epoch}) 190 | 191 | # Tensorboard summary logging 192 | if steps % a.summary_interval == 0: 193 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) 194 | sw.add_scalar("training/mel_spec_error", mel_error, steps) 195 | sw.add_scalar("training/ada_losses", a_loss, steps) 196 | 197 | # Validation 198 | if steps % a.validation_interval == 0: # and steps != 0: 199 | generator_t.eval() 200 | torch.cuda.empty_cache() 201 | val_err_tot = 0 202 | with torch.no_grad(): 203 | for j, batch in enumerate(validation_loader): 204 | x, y, _, y_mel = batch 205 | y_g_hat, _ = generator_t(x.to(device)) 206 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 207 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, 208 | h.hop_size, h.win_size, 209 | h.fmin, h.fmax_for_loss) 210 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() 211 | 212 | if j <= 4: 213 | # if steps == 0: 214 | # sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) 215 | # sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) 216 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) 217 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) 218 | 219 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate) 220 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, 221 | h.sampling_rate, h.hop_size, h.win_size, 222 | h.fmin, h.fmax) 223 | sw.add_figure('generated/y_hat_spec_{}'.format(j), 224 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps) 225 | 226 | val_err = val_err_tot / (j + 1) 227 | sw.add_scalar("validation/mel_spec_error", val_err, steps) 228 | 229 | generator_t.train() 230 | 231 | steps += 1 232 | 233 | scheduler_g.step() 234 | scheduler_d.step() 235 | 236 | if rank == 0: 237 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) 238 | 239 | 240 | def main(): 241 | print('Initializing Training Process..') 242 | 243 | parser = argparse.ArgumentParser() 244 | 245 | parser.add_argument('--group_name', default=None) 246 | parser.add_argument('--input_wavs_dir', default='DataBaker') 247 | parser.add_argument('--input_mels_dir', 
default='ft_dataset') 248 | parser.add_argument('--input_training_file', 249 | default='DataBaker/train.txt') 250 | parser.add_argument('--input_validation_file', 251 | default='DataBaker/val.txt') 252 | parser.add_argument('--checkpoint_path', default='cp_ada_hifigan') 253 | parser.add_argument('--config', default='') 254 | parser.add_argument('--training_epochs', default=31000, type=int) 255 | parser.add_argument('--stdout_interval', default=5, type=int) 256 | parser.add_argument('--checkpoint_interval', default=1000, type=int) 257 | parser.add_argument('--summary_interval', default=10, type=int) 258 | parser.add_argument('--validation_interval', default=100, type=int) 259 | parser.add_argument('--fine_tuning', default=False, type=bool) 260 | 261 | a = parser.parse_args() 262 | 263 | with open(a.config) as f: 264 | data = f.read() 265 | 266 | json_config = json.loads(data) 267 | h = AttrDict(json_config) 268 | build_env(a.config, 'config.json', a.checkpoint_path) 269 | 270 | torch.manual_seed(h.seed) 271 | if torch.cuda.is_available(): 272 | torch.cuda.manual_seed(h.seed) 273 | h.num_gpus = torch.cuda.device_count() 274 | h.batch_size = int(h.batch_size / h.num_gpus) 275 | print('Batch size per GPU :', h.batch_size) 276 | else: 277 | pass 278 | 279 | if h.num_gpus > 1: 280 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,)) 281 | else: 282 | train(0, a, h) 283 | 284 | 285 | if __name__ == '__main__': 286 | main() 287 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | from utils import init_weights, get_padding 7 | 8 | LRELU_SLOPE = 0.1 9 | 10 | 11 | class ResBlock1(torch.nn.Module): 12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 13 | super(ResBlock1, self).__init__() 14 | self.h = h 15 | self.convs1 = nn.ModuleList([ 16 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 17 | padding=get_padding(kernel_size, dilation[0]))), 18 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 19 | padding=get_padding(kernel_size, dilation[1]))), 20 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 21 | padding=get_padding(kernel_size, dilation[2]))) 22 | ]) 23 | self.convs1.apply(init_weights) 24 | 25 | self.convs2 = nn.ModuleList([ 26 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 27 | padding=get_padding(kernel_size, 1))), 28 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 29 | padding=get_padding(kernel_size, 1))), 30 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 31 | padding=get_padding(kernel_size, 1))) 32 | ]) 33 | self.convs2.apply(init_weights) 34 | 35 | def forward(self, x): 36 | for c1, c2 in zip(self.convs1, self.convs2): 37 | xt = F.leaky_relu(x, LRELU_SLOPE) 38 | xt = c1(xt) 39 | xt = F.leaky_relu(xt, LRELU_SLOPE) 40 | xt = c2(xt) 41 | x = xt + x 42 | return x 43 | 44 | def remove_weight_norm(self): 45 | for l in self.convs1: 46 | remove_weight_norm(l) 47 | for l in self.convs2: 48 | remove_weight_norm(l) 49 | 50 | 51 | class ResBlock2(torch.nn.Module): 52 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 53 | super(ResBlock2, self).__init__() 54 | self.h 
= h 55 | self.convs = nn.ModuleList([ 56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 57 | padding=get_padding(kernel_size, dilation[0]))), 58 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 59 | padding=get_padding(kernel_size, dilation[1]))) 60 | ]) 61 | self.convs.apply(init_weights) 62 | 63 | def forward(self, x): 64 | for c in self.convs: 65 | xt = F.leaky_relu(x, LRELU_SLOPE) 66 | xt = c(xt) 67 | x = xt + x 68 | return x 69 | 70 | def remove_weight_norm(self): 71 | for l in self.convs: 72 | remove_weight_norm(l) 73 | 74 | 75 | 76 | class Generator(torch.nn.Module): 77 | def __init__(self, h): 78 | super(Generator, self).__init__() 79 | self.h = h # config_v1_json 80 | self.num_kernels = len(h.resblock_kernel_sizes) # len([3,7,11]) 81 | self.num_upsamples = len(h.upsample_rates) # len([8,8,2,2]) 82 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) # upsample_initial_channel = 512 83 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 # resblock = 1 84 | 85 | self.ups = nn.ModuleList() 86 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): # upsample_kernel_sizes = [16,16,4,4] 87 | self.ups.append(weight_norm( 88 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 89 | k, u, padding=(k-u)//2))) 90 | self.resblocks = nn.ModuleList() 91 | for i in range(len(self.ups)): 92 | ch = h.upsample_initial_channel//(2**(i+1)) 93 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 94 | self.resblocks.append(resblock(h, ch, k, d)) 95 | 96 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 97 | self.ups.apply(init_weights) 98 | self.conv_post.apply(init_weights) 99 | 100 | def forward(self, x): 101 | x = self.conv_pre(x) 102 | for i in range(self.num_upsamples): 103 | x = F.leaky_relu(x, LRELU_SLOPE) 104 | x = self.ups[i](x) 105 | xs = None 106 | for j in range(self.num_kernels): 107 | if xs is None: 108 | xs = self.resblocks[i*self.num_kernels+j](x) 109 | else: 110 | xs += self.resblocks[i*self.num_kernels+j](x) 111 | x = xs / self.num_kernels 112 | x = F.leaky_relu(x) 113 | x = self.conv_post(x) 114 | x = torch.tanh(x) 115 | return x 116 | 117 | def remove_weight_norm(self): 118 | print('Removing weight norm...') 119 | for l in self.ups: 120 | remove_weight_norm(l) 121 | for l in self.resblocks: 122 | l.remove_weight_norm() 123 | remove_weight_norm(self.conv_pre) 124 | remove_weight_norm(self.conv_post) 125 | 126 | 127 | class Generator_S(torch.nn.Module): 128 | def __init__(self, h): 129 | super(Generator_S, self).__init__() 130 | self.h = h # config_v1_json 131 | self.num_kernels = len(h.resblock_kernel_sizes) # len([3,7,11]) 132 | self.num_upsamples = len(h.upsample_rates) # len([8,8,2,2]) 133 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) # upsample_initial_channel = 512 134 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 # resblock = 1 135 | 136 | self.ups = nn.ModuleList() 137 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): # upsample_kernel_sizes = [16,16,4,4] 138 | self.ups.append(weight_norm( 139 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 140 | k, u, padding=(k-u)//2))) 141 | self.resblocks = nn.ModuleList() 142 | for i in range(len(self.ups)): 143 | ch = h.upsample_initial_channel//(2**(i+1)) 144 | for j, (k, d) in 
class Generator_S(torch.nn.Module):
    def __init__(self, h):
        super(Generator_S, self).__init__()
        self.h = h  # config_v1_json
        self.num_kernels = len(h.resblock_kernel_sizes)  # len([3,7,11])
        self.num_upsamples = len(h.upsample_rates)  # len([8,8,2,2])
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))  # upsample_initial_channel = 512
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2  # resblock = 1

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):  # upsample_kernel_sizes = [16,16,4,4]
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        fmap = []
        x = self.conv_pre(x)
        fmap.append(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
            fmap.append(x)
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.tanh(x)
        return x, fmap

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class Generator_T(torch.nn.Module):
    def __init__(self, h):
        super(Generator_T, self).__init__()
        self.h = h  # config_v1_json
        self.num_kernels = len(h.resblock_kernel_sizes)  # len([3,7,11])
        self.num_upsamples = len(h.upsample_rates)  # len([8,8,2,2])
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))  # upsample_initial_channel = 512
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2  # resblock = 1

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):  # upsample_kernel_sizes = [16,16,4,4]
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        fmap = []
        x = self.conv_pre(x)
        fmap.append(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
            fmap.append(x)
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.tanh(x)
        return x, fmap

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
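
# DiscriminatorP reshapes the 1-D waveform into a (time // period, period)
# grid and convolves only along the first axis, so each sub-discriminator
# looks at equally spaced samples (every period-th sample).
# MultiPeriodDiscriminator combines periods 2, 3, 5, 7 and 11, which are
# pairwise coprime, so the periodic structure captured by the branches
# overlaps as little as possible.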
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []
        b, c, t = x.shape

        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i-1](y)
                y_hat = self.meanpools[i-1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    '''
    LS-GAN
    Xudong Mao, Qing Li, Haoran Xie, Raymond Y. K. Lau, Zhen Wang, and Stephen Paul Smolley.
    Least squares generative adversarial networks.
    In Proceedings of the IEEE International Conference on Computer Vision, pages 2794–2802, 2017.

    The discriminator is trained to classify ground-truth samples as 1 and samples synthesized by
    the generator as 0. The generator is trained to fool the discriminator by improving the sample
    quality until its outputs are classified as values close to 1.
    '''

    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
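
# Behaviour of ada_loss, summarized here as a reading aid (inferred from the
# code below, not part of the original comments): each feature map is
# flattened per sample, a batch-by-batch cosine-similarity matrix is built
# for the source (s) and target (t) feature maps, every row (with its
# diagonal entry removed) is normalized with a softmax, and F.kl_div is
# applied with the target row's log-probabilities as input and the source
# row's probabilities as target. The result is averaged over rows and summed
# over feature maps; the pairwise similarities are only meaningful for batch
# sizes greater than 1.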
def ada_loss(fmap_s, fmap_t):
    loss = 0
    ada_losses = []
    for s, t in zip(fmap_s, fmap_t):
        l_losses = 0
        s = s.reshape(s.shape[0], -1)
        t = t.reshape(t.shape[0], -1)
        cs_s = torch.matmul(s, s.T) / torch.matmul(torch.linalg.norm(s, axis=1).reshape(-1, 1),
                                                   torch.linalg.norm(s, axis=1).reshape(1, -1))
        cs_t = torch.matmul(t, t.T) / torch.matmul(torch.linalg.norm(t, axis=1).reshape(-1, 1),
                                                   torch.linalg.norm(t, axis=1).reshape(1, -1))
        # calculate kl-dist
        for i, (cs_s_i, cs_t_i) in enumerate(zip(cs_s, cs_t)):
            cs_s_i = torch.softmax(torch.cat((cs_s_i[:i], cs_s_i[i+1:]), 0), 0)
            cs_t_i = torch.softmax(torch.cat((cs_t_i[:i], cs_t_i[i+1:]), 0), 0)
            l_losses += F.kl_div(cs_t_i.log(), cs_s_i, reduction='mean')
        l_losses /= cs_s.shape[0]
        loss += l_losses
        ada_losses.append(l_losses)

    return loss, ada_losses
--------------------------------------------------------------------------------