├── DataBaker
│   ├── val.txt
│   ├── 000001.wav
│   ├── 000002.wav
│   ├── 000003.wav
│   ├── 000004.wav
│   ├── 000005.wav
│   ├── 000006.wav
│   ├── 000007.wav
│   ├── 000008.wav
│   ├── 000009.wav
│   ├── 000010.wav
│   └── train.txt
├── export_onnx.py
├── test_files
│   └── 000010.wav
├── generated_files
│   ├── 000010.wav
│   ├── hifi-gan000010_generated.wav
│   └── adahifi-gan000010_generated.wav
├── requirements.txt
├── env.py
├── config_v1.json
├── ada_config_v1.json
├── LICENSE
├── utils.py
├── preprocess_aishell3.py
├── inference_e2e.py
├── AISHELL-3
│   └── val.txt
├── inference.py
├── README.md
├── meldataset.py
├── train_hifi_gan.py
├── train_ada_hifi_gan.py
└── models.py
/DataBaker/val.txt:
--------------------------------------------------------------------------------
1 | 000010|
2 |
--------------------------------------------------------------------------------
/export_onnx.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 |
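6 | # NOTE: this script is only a stub in the repository. What follows is a
7 | # commented-out sketch (an assumption, not the authors' export code) of how the
8 | # trained Generator could be exported to ONNX; the checkpoint path, output file
9 | # name, and dummy mel length are placeholders.
10 | #
11 | # import json
12 | # from env import AttrDict
13 | # from models import Generator
14 | #
15 | # with open('config_v1.json') as f:
16 | #     h = AttrDict(json.loads(f.read()))
17 | # generator = Generator(h)
18 | # state_dict_g = torch.load('cp_hifigan/g_00100000', map_location='cpu')  # placeholder checkpoint
19 | # generator.load_state_dict(state_dict_g['generator'])
20 | # generator.eval()
21 | # generator.remove_weight_norm()
22 | #
23 | # dummy_mel = torch.randn(1, 80, 100)  # (batch, num_mels, frames)
24 | # torch.onnx.export(generator, dummy_mel, 'generator.onnx',
25 | #                   input_names=['mel'], output_names=['audio'],
26 | #                   dynamic_axes={'mel': {2: 'frames'}, 'audio': {2: 'samples'}})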
--------------------------------------------------------------------------------
/DataBaker/000001.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000001.wav
--------------------------------------------------------------------------------
/DataBaker/000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000002.wav
--------------------------------------------------------------------------------
/DataBaker/000003.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000003.wav
--------------------------------------------------------------------------------
/DataBaker/000004.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000004.wav
--------------------------------------------------------------------------------
/DataBaker/000005.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000005.wav
--------------------------------------------------------------------------------
/DataBaker/000006.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000006.wav
--------------------------------------------------------------------------------
/DataBaker/000007.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000007.wav
--------------------------------------------------------------------------------
/DataBaker/000008.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000008.wav
--------------------------------------------------------------------------------
/DataBaker/000009.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000009.wav
--------------------------------------------------------------------------------
/DataBaker/000010.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/DataBaker/000010.wav
--------------------------------------------------------------------------------
/test_files/000010.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/test_files/000010.wav
--------------------------------------------------------------------------------
/generated_files/000010.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/generated_files/000010.wav
--------------------------------------------------------------------------------
/DataBaker/train.txt:
--------------------------------------------------------------------------------
1 | 000001|
2 | 000002|
3 | 000003|
4 | 000004|
5 | 000005|
6 | 000006|
7 | 000007|
8 | 000008|
9 | 000009|
10 |
--------------------------------------------------------------------------------
/generated_files/hifi-gan000010_generated.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/generated_files/hifi-gan000010_generated.wav
--------------------------------------------------------------------------------
/generated_files/adahifi-gan000010_generated.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuan1615/AdaVocoder/HEAD/generated_files/adahifi-gan000010_generated.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | numpy==1.17.4
3 | librosa==0.7.2
4 | scipy==1.4.1
5 | tensorboard==2.0
6 | soundfile==0.10.3.post1
7 | matplotlib==3.1.3
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
/config_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 32,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/ada_config_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 8,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 8192,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 256,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 22050,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Xin Yuan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import matplotlib
4 | import torch
5 | from torch.nn.utils import weight_norm
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def scan_checkpoint(cp_dir, prefix):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern)
55 | if len(cp_list) == 0:
56 | return None
57 | return sorted(cp_list)[-1]
58 |
59 |
--------------------------------------------------------------------------------
/preprocess_aishell3.py:
--------------------------------------------------------------------------------
1 | # import os
2 | # from shutil import copyfile
3 | # import librosa
4 | # from scipy.io import wavfile
5 | # import numpy as np
6 | # from tqdm import tqdm
7 | #
8 | # save_path = '/home/admin/yuanxin/2.TTSData/AISHELL-3/wavs'
9 | #
10 | # train_path = '/home/admin/yuanxin/2.TTSData/AISHELL-3/train/wav'
11 | # train_file = '/home/admin/yuanxin/2.TTSData/AISHELL-3/train/content.txt'
12 | # test_path = '/home/admin/yuanxin/2.TTSData/AISHELL-3/test/wav'
13 | # test_file = '/home/admin/yuanxin/2.TTSData/AISHELL-3/test/content.txt'
14 | #
15 | #
16 | # file_list = os.listdir(train_path)
17 | # for speaker_name in file_list:
18 | # wave_list = os.listdir(os.path.join(train_path, speaker_name))
19 | # for wavname in wave_list:
20 | # copyfile(os.path.join(train_path, speaker_name, wavname), os.path.join(save_path, wavname))
21 | #
22 | # file_list = os.listdir(test_path)
23 | # for speaker_name in file_list:
24 | # wave_list = os.listdir(os.path.join(test_path, speaker_name))
25 | # for wavname in wave_list:
26 | # copyfile(os.path.join(test_path, speaker_name, wavname), os.path.join(save_path, wavname))
27 | #
28 |
29 | # with open(train_file, 'r', encoding='utf-8') as f:
30 | # lines = f.readlines()
31 | # with open('AISHELL-3/train.txt', 'w', encoding='utf-8') as f1:
32 | # for l in lines:
33 | # l = l.split('.')[0] + '|'
34 | # f1.write(l + '\n')
35 | #
36 | # with open(test_file, 'r', encoding='utf-8') as f:
37 | # lines = f.readlines()
38 | # with open('AISHELL-3/val.txt', 'w', encoding='utf-8') as f1:
39 | # for l in lines:
40 | # l = l.split('.')[0] + '|'
41 | # f1.write(l + '\n')
42 | #
43 |
44 | # file_list = os.listdir(save_path)
45 | # for file in tqdm(file_list):
46 | # wav, _ = librosa.load(os.path.join(save_path, file), 22050)
47 | # wav = wav / max(abs(wav)) * 32767.0
48 | # wavfile.write(
49 | # os.path.join(save_path, file),
50 | # 22050,
51 | # wav.astype(np.int16),
52 | # )
53 |
--------------------------------------------------------------------------------
/inference_e2e.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import glob
4 | import os
5 | import numpy as np
6 | import argparse
7 | import json
8 | import torch
9 | from scipy.io.wavfile import write
10 | from env import AttrDict
11 | from meldataset import MAX_WAV_VALUE
12 | from models import Generator
13 |
14 | h = None
15 | device = None
16 |
17 |
18 | def load_checkpoint(filepath, device):
19 | assert os.path.isfile(filepath)
20 | print("Loading '{}'".format(filepath))
21 | checkpoint_dict = torch.load(filepath, map_location=device)
22 | print("Complete.")
23 | return checkpoint_dict
24 |
25 |
26 | def scan_checkpoint(cp_dir, prefix):
27 | pattern = os.path.join(cp_dir, prefix + '*')
28 | cp_list = glob.glob(pattern)
29 | if len(cp_list) == 0:
30 | return ''
31 | return sorted(cp_list)[-1]
32 |
33 |
34 | def inference(a):
35 | generator = Generator(h).to(device)
36 |
37 | state_dict_g = load_checkpoint(a.checkpoint_file, device)
38 | generator.load_state_dict(state_dict_g['generator'])
39 |
40 | filelist = os.listdir(a.input_mels_dir)
41 |
42 | os.makedirs(a.output_dir, exist_ok=True)
43 |
44 | generator.eval()
45 | generator.remove_weight_norm()
46 | with torch.no_grad():
47 | for i, filname in enumerate(filelist):
48 | x = np.load(os.path.join(a.input_mels_dir, filname))
49 | x = torch.FloatTensor(x).to(device)
50 | y_g_hat = generator(x)
51 | audio = y_g_hat.squeeze()
52 | audio = audio * MAX_WAV_VALUE
53 | audio = audio.cpu().numpy().astype('int16')
54 |
55 | output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + '_generated_e2e.wav')
56 | write(output_file, h.sampling_rate, audio)
57 | print(output_file)
58 |
59 |
60 | def main():
61 | print('Initializing Inference Process..')
62 |
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument('--input_mels_dir', default='test_mel_files')
65 | parser.add_argument('--output_dir', default='generated_files_from_mel')
66 | parser.add_argument('--checkpoint_file', required=True)
67 | a = parser.parse_args()
68 |
69 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
70 | with open(config_file) as f:
71 | data = f.read()
72 |
73 | global h
74 | json_config = json.loads(data)
75 | h = AttrDict(json_config)
76 |
77 | torch.manual_seed(h.seed)
78 | global device
79 | if torch.cuda.is_available():
80 | torch.cuda.manual_seed(h.seed)
81 | device = torch.device('cuda')
82 | else:
83 | device = torch.device('cpu')
84 |
85 | inference(a)
86 |
87 |
88 | if __name__ == '__main__':
89 | main()
90 |
91 |
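92 | # A minimal sketch (not part of the original script) of how an input mel file for
93 | # --input_mels_dir could be produced from a wav, reusing this repository's
94 | # mel_spectrogram(); `h` is the loaded config (as in main() above) and the paths
95 | # are placeholders:
96 | #
97 | #   from meldataset import mel_spectrogram, load_wav
98 | #   wav, sr = load_wav('test_files/000010.wav')
99 | #   wav = torch.FloatTensor(wav / MAX_WAV_VALUE).unsqueeze(0)
100 | #   mel = mel_spectrogram(wav, h.n_fft, h.num_mels, h.sampling_rate,
101 | #                         h.hop_size, h.win_size, h.fmin, h.fmax)
102 | #   np.save('test_mel_files/000010.npy', mel.numpy())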
--------------------------------------------------------------------------------
/AISHELL-3/val.txt:
--------------------------------------------------------------------------------
1 | SSB06930002|
2 | SSB06930003|
3 | SSB06930004|
4 | SSB06930005|
5 | SSB06930006|
6 | SSB06930007|
7 | SSB06930008|
8 | SSB06930010|
9 | SSB06930011|
10 | SSB06930012|
11 | SSB06930013|
12 | SSB06930014|
13 | SSB06930015|
14 | SSB06930016|
15 | SSB06930017|
16 | SSB06930018|
17 | SSB06930019|
18 | SSB06930021|
19 | SSB06930022|
20 | SSB06930023|
21 | SSB06930024|
22 | SSB06930025|
23 | SSB06930026|
24 | SSB06930027|
25 | SSB06930028|
26 | SSB06930029|
27 | SSB06930030|
28 | SSB06930031|
29 | SSB06930032|
30 | SSB06930033|
31 | SSB06930034|
32 | SSB06930035|
33 | SSB06930036|
34 | SSB06930037|
35 | SSB06930039|
36 | SSB06930040|
37 | SSB06930041|
38 | SSB06930042|
39 | SSB06930044|
40 | SSB06930045|
41 | SSB06930047|
42 | SSB06930048|
43 | SSB06930049|
44 | SSB06930051|
45 | SSB06930053|
46 | SSB06930054|
47 | SSB06930056|
48 | SSB06930057|
49 | SSB06930058|
50 | SSB06930059|
51 | SSB06930060|
52 | SSB06930062|
53 | SSB06930063|
54 | SSB06930064|
55 | SSB06930065|
56 | SSB06930066|
57 | SSB06930067|
58 | SSB06930068|
59 | SSB06930070|
60 | SSB06930071|
61 | SSB06930072|
62 | SSB06930074|
63 | SSB06930075|
64 | SSB06930076|
65 | SSB06930077|
66 | SSB06930078|
67 | SSB06930079|
68 | SSB06930080|
69 | SSB06930020|
70 | SSB06930038|
71 | SSB06930061|
72 | SSB06930081|
73 | SSB06930099|
74 | SSB06930120|
75 | SSB06930141|
76 | SSB06930160|
77 | SSB06930181|
78 | SSB06930201|
79 | SSB06930219|
80 | SSB06930238|
81 | SSB06930257|
82 | SSB06930277|
83 | SSB06930295|
84 | SSB06930314|
85 | SSB06930335|
86 | SSB06930354|
87 | SSB06930374|
88 | SSB06930392|
89 | SSB06930412|
90 | SSB06930432|
91 | SSB06930451|
92 | SSB06930082|
93 | SSB06930083|
94 | SSB06930084|
95 | SSB06930085|
96 | SSB06930086|
97 | SSB06930087|
98 | SSB06930088|
99 | SSB06930089|
100 | SSB06930090|
101 | SSB06930091|
102 | SSB06930092|
103 | SSB06930093|
104 | SSB06930094|
105 | SSB06930095|
106 | SSB06930096|
107 | SSB06930097|
108 | SSB06930098|
109 | SSB06930100|
110 | SSB06930101|
111 | SSB06930102|
112 | SSB06930104|
113 | SSB06930105|
114 | SSB06930106|
115 | SSB06930108|
116 | SSB06930109|
117 | SSB06930110|
118 | SSB06930111|
119 | SSB06930112|
120 | SSB06930113|
121 | SSB06930114|
122 | SSB06930115|
123 | SSB06930117|
124 | SSB06930118|
125 | SSB06930119|
126 | SSB06930121|
127 | SSB06930122|
128 | SSB06930123|
129 | SSB06930124|
130 | SSB06930126|
131 | SSB06930127|
132 | SSB06930128|
133 | SSB06930129|
134 | SSB06930130|
135 | SSB06930131|
136 | SSB06930132|
137 | SSB06930134|
138 | SSB06930135|
139 | SSB06930136|
140 | SSB06930137|
141 | SSB06930138|
142 | SSB06930140|
143 | SSB06930142|
144 | SSB06930143|
145 | SSB06930144|
146 | SSB06930145|
147 | SSB06930146|
148 | SSB06930147|
149 | SSB06930148|
150 | SSB06930149|
151 | SSB06930150|
152 | SSB06930151|
153 | SSB06930152|
154 | SSB06930153|
155 | SSB06930154|
156 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import glob
4 | import os
5 | import argparse
6 | import json
7 |
8 | import numpy as np
9 | import torch
10 | from scipy.io.wavfile import write
11 | from env import AttrDict
12 | from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
13 | from models import Generator
14 |
15 | h = None
16 | device = None
17 |
18 |
19 | def load_checkpoint(filepath, device):
20 | assert os.path.isfile(filepath)
21 | print("Loading '{}'".format(filepath))
22 | checkpoint_dict = torch.load(filepath, map_location=device)
23 | print("Complete.")
24 | return checkpoint_dict
25 |
26 |
27 | def get_mel(x):
28 | return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
29 |
30 |
31 | def scan_checkpoint(cp_dir, prefix):
32 | pattern = os.path.join(cp_dir, prefix + '*')
33 | cp_list = glob.glob(pattern)
34 | if len(cp_list) == 0:
35 | return ''
36 | return sorted(cp_list)[-1]
37 |
38 |
39 | def inference(a):
40 | generator = Generator(h).to(device)
41 |
42 | state_dict_g = load_checkpoint(a.checkpoint_file, device)
43 | generator.load_state_dict(state_dict_g['generator'])
44 |
45 | filelist = os.listdir(a.input_wavs_dir)
46 |
47 | os.makedirs(a.output_dir, exist_ok=True)
48 |
49 | generator.eval()
50 | generator.remove_weight_norm()
51 | with torch.no_grad():
52 | for i, filname in enumerate(filelist):
53 | wav, sr = load_wav(os.path.join(a.input_wavs_dir, filname))
54 | wav = wav / MAX_WAV_VALUE
55 | wav = torch.FloatTensor(wav).to(device)
56 | x = get_mel(wav.unsqueeze(0))
57 | y_g_hat = generator(x)
58 | audio = y_g_hat.squeeze()
59 | audio = audio * MAX_WAV_VALUE
60 | audio = audio.cpu().numpy().astype('int16')
61 |
62 | output_file = os.path.join(a.output_dir, a.model_name + os.path.splitext(filname)[0] + '_generated.wav')
63 | write(output_file, h.sampling_rate, audio)
64 | print(output_file)
65 |
66 |
67 | def main():
68 | print('Initializing Inference Process..')
69 |
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('--input_wavs_dir', default='test_files')
72 | parser.add_argument('--output_dir', default='generated_files')
73 | parser.add_argument('--model_name', required=True)
74 | parser.add_argument('--checkpoint_file', required=True)
75 | a = parser.parse_args()
76 |
77 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
78 | with open(config_file) as f:
79 | data = f.read()
80 |
81 | global h
82 | json_config = json.loads(data)
83 | h = AttrDict(json_config)
84 |
85 | torch.manual_seed(h.seed)
86 | global device
87 | if torch.cuda.is_available():
88 | torch.cuda.manual_seed(h.seed)
89 | device = torch.device('cuda')
90 | else:
91 | device = torch.device('cpu')
92 |
93 | inference(a)
94 |
95 |
96 | if __name__ == '__main__':
97 | main()
98 |
99 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaVocoder: Adaptive Vocoder for Custom Voice
2 |
3 | In our [paper](https://www.isca-speech.org/archive/interspeech_2022/yuan22_interspeech.html),
4 | we proposed AdaVocoder: Adaptive Vocoder for Custom Voice.
5 | We provide our implementation and pretrained models for `AdaHiFi-GAN` as open source in this repository.
6 |
7 | **Abstract:**
8 |
9 | Custom voice builds a personal speech synthesis system by adapting a source speech synthesis model
10 | to a target speaker using only a few recordings of that speaker. The usual solution to custom voice is
11 | to combine an adaptive acoustic model with a robust vocoder. However, training a robust vocoder usually requires a multi-speaker dataset
12 | covering a wide range of ages and timbres, so that the trained vocoder generalizes to unseen speakers. Collecting such a
13 | multi-speaker dataset is difficult, and its distribution always has a mismatch with the distribution of the target speaker's data.
14 |
15 | This paper proposes an adaptive vocoder for custom voice, approaching the above problems from a novel perspective.
16 | The adaptive vocoder mainly uses a cross-domain consistency loss to overcome the overfitting that GAN-based neural
17 | vocoders suffer from during few-shot transfer learning. We construct two adaptive vocoders, AdaMelGAN and AdaHiFi-GAN.
18 | First, we pre-train the source vocoder models on the AISHELL-3 and CSMSC datasets, respectively.
19 | Then, we fine-tune them on the internal VXI-children dataset with a small amount of adaptation data.
20 | The empirical results show that a high-quality custom voice system can be built by combining an adaptive acoustic model with an adaptive vocoder.
21 |
22 | ## Pre-requisites
23 | 1. Python >= 3.6
24 | 2. Clone this repository.
25 | 3. Install Python requirements. Please refer to [requirements.txt](requirements.txt).
26 | 4. Download and extract the [AISHELL3 dataset](http://www.aishelltech.com/aishell_3). Move all wav files into `AISHELL-3/wavs` and resample them to `22050` Hz
27 |    (see `preprocess_aishell3.py` for a reference script), then rename the folder or create a link to it: `ln -s /path/to/AISHELL-3/wavs DUMMY1`
28 |
29 |
30 | ## Training HiFi-GAN
31 | ```
32 | python train_hifi_gan.py --config config_v1.json
33 | ```
34 | - Tensorboard
35 | ```
36 | tensorboard --logdir cp_hifigan/logs/ --bind_all
37 | ```
38 |
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
40 | You can change the path with the `--checkpoint_path` option.
41 |
42 | ## Pretrained Model
43 | You can also use the pretrained models we provide.
44 | [Download AISHELL3 pretrained models](https://drive.google.com/file/d/1lqp-8mQIultA2nQ9lY3SNyUpqDpZTHLk/view?usp=sharing)
45 |
46 |
47 | ## Training AdaHiFi-GAN
48 | First, save the pre-trained `AISHELL-3` model to the `cp_ada_hifigan` directory.
49 |
50 | For confidentiality reasons, the VXI-children dataset is not used here; instead, we test with the child voice samples shared by [Data-Baker](https://www.data-baker.com/).
51 |
52 | ```
53 | python train_ada_hifi_gan.py --config ada_config_v1.json
54 | ```
55 | - Tensorboard
56 | ```
57 | tensorboard --logdir cp_ada_hifigan/logs/ --bind_all
58 | ```
59 | Checkpoints and a copy of the configuration file are saved in the `cp_ada_hifigan` directory by default.
60 | You can change the path with the `--checkpoint_path` option.
61 |
62 |
63 | ## Inference from wav file
64 | 1. Make a `test_files` directory and copy wav files into it.
65 | 2. Run the following command.
66 | ```
67 | python inference.py --checkpoint_file [generator checkpoint file path] --model_name [hifi-gan or adahifi-gan]
68 | ```
69 | Generated wav files are saved in `generated_files` by default.
70 | You can change the path with the `--output_dir` option.
71 |
72 | ## [Some Samples](https://yuan1615.github.io/2022/09/21/AdaVocoder/)
73 |
74 |
75 | ## Acknowledgements
76 | Our implementation is based on [HiFi-GAN](https://github.com/jik876/hifi-gan).
77 |
78 |
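79 | ## Inference from mel-spectrogram file
80 | `inference_e2e.py` generates audio directly from mel-spectrogram `.npy` files (for example, outputs of an acoustic model).
81 | 1. Make a `test_mel_files` directory and copy the `.npy` mel files into it.
82 | 2. Run the following command.
83 | ```
84 | python inference_e2e.py --checkpoint_file [generator checkpoint file path]
85 | ```
86 | Generated wav files are saved in `generated_files_from_mel` by default.
87 | You can change the paths with the `--input_mels_dir` and `--output_dir` options.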
--------------------------------------------------------------------------------
/meldataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | import torch.utils.data
6 | import numpy as np
7 | from librosa.util import normalize
8 | from scipy.io.wavfile import read
9 | from librosa.filters import mel as librosa_mel_fn
10 |
11 | MAX_WAV_VALUE = 32768.0
12 |
13 |
14 | def load_wav(full_path):
15 | sampling_rate, data = read(full_path)
16 | return data, sampling_rate
17 |
18 |
19 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
20 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
21 |
22 |
23 | def dynamic_range_decompression(x, C=1):
24 | return np.exp(x) / C
25 |
26 |
27 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
28 | return torch.log(torch.clamp(x, min=clip_val) * C)
29 |
30 |
31 | def dynamic_range_decompression_torch(x, C=1):
32 | return torch.exp(x) / C
33 |
34 |
35 | def spectral_normalize_torch(magnitudes):
36 | output = dynamic_range_compression_torch(magnitudes)
37 | return output
38 |
39 |
40 | def spectral_de_normalize_torch(magnitudes):
41 | output = dynamic_range_decompression_torch(magnitudes)
42 | return output
43 |
44 |
45 | mel_basis = {}
46 | hann_window = {}
47 |
48 |
49 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
50 | if torch.min(y) < -1.:
51 | print('min value is ', torch.min(y))
52 | if torch.max(y) > 1.:
53 | print('max value is ', torch.max(y))
54 |
55 | global mel_basis, hann_window
56 | if str(fmax) + '_' + str(y.device) not in mel_basis:  # key must match the one used to store the basis below
57 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
58 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
59 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
60 |
61 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
62 | y = y.squeeze(1)
63 |
64 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
65 | center=center, pad_mode='reflect', normalized=False, onesided=True)
66 |
67 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
68 |
69 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
70 | spec = spectral_normalize_torch(spec)
71 |
72 | return spec
73 |
74 |
75 | def get_dataset_filelist(a):
76 | with open(a.input_training_file, 'r', encoding='utf-8') as fi:
77 | training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
78 | for x in fi.read().split('\n') if len(x) > 0]
79 |
80 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
81 | validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
82 | for x in fi.read().split('\n') if len(x) > 0]
83 | return training_files, validation_files
84 |
85 |
86 | class MelDataset(torch.utils.data.Dataset):
87 | def __init__(self, training_files, segment_size, n_fft, num_mels,
88 | hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
89 | device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
90 | self.audio_files = training_files
91 | random.seed(1234)
92 | if shuffle:
93 | random.shuffle(self.audio_files)
94 | self.segment_size = segment_size
95 | self.sampling_rate = sampling_rate
96 | self.split = split
97 | self.n_fft = n_fft
98 | self.num_mels = num_mels
99 | self.hop_size = hop_size
100 | self.win_size = win_size
101 | self.fmin = fmin
102 | self.fmax = fmax
103 | self.fmax_loss = fmax_loss
104 | self.cached_wav = None
105 | self.n_cache_reuse = n_cache_reuse
106 | self._cache_ref_count = 0
107 | self.device = device
108 | self.fine_tuning = fine_tuning
109 | self.base_mels_path = base_mels_path
110 |
111 | def __getitem__(self, index):
112 | filename = self.audio_files[index]
113 | if self._cache_ref_count == 0:
114 | audio, sampling_rate = load_wav(filename)
115 | audio = audio / MAX_WAV_VALUE
116 | if not self.fine_tuning:
117 | audio = normalize(audio) * 0.95
118 | self.cached_wav = audio
119 | if sampling_rate != self.sampling_rate:
120 | raise ValueError("{} SR doesn't match target {} SR".format(
121 | sampling_rate, self.sampling_rate))
122 | self._cache_ref_count = self.n_cache_reuse
123 | else:
124 | audio = self.cached_wav
125 | self._cache_ref_count -= 1
126 |
127 | audio = torch.FloatTensor(audio)
128 | audio = audio.unsqueeze(0)
129 |
130 | if not self.fine_tuning:
131 | if self.split:
132 | if audio.size(1) >= self.segment_size:
133 | max_audio_start = audio.size(1) - self.segment_size
134 | audio_start = random.randint(0, max_audio_start)
135 | audio = audio[:, audio_start:audio_start+self.segment_size]
136 | else:
137 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
138 |
139 | mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
140 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
141 | center=False)
142 | else:
143 | mel = np.load(
144 | os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
145 | mel = torch.from_numpy(mel)
146 |
147 | if len(mel.shape) < 3:
148 | mel = mel.unsqueeze(0)
149 |
150 | if self.split:
151 | frames_per_seg = math.ceil(self.segment_size / self.hop_size)
152 |
153 | if audio.size(1) >= self.segment_size:
154 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
155 | mel = mel[:, :, mel_start:mel_start + frames_per_seg]
156 | audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
157 | else:
158 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
159 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
160 |
161 | mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
162 | self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
163 | center=False)
164 |
165 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
166 |
167 | def __len__(self):
168 | return len(self.audio_files)
169 |
--------------------------------------------------------------------------------
/train_hifi_gan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.simplefilter(action='ignore', category=FutureWarning)
4 | import itertools
5 | import os
6 | import time
7 | import argparse
8 | import json
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DistributedSampler, DataLoader
13 | import torch.multiprocessing as mp
14 | from torch.distributed import init_process_group
15 | from torch.nn.parallel import DistributedDataParallel
16 | from env import AttrDict, build_env
17 | from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
18 | from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss, \
19 | discriminator_loss
20 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
21 |
22 | torch.backends.cudnn.benchmark = True
23 |
24 |
25 | def train(rank, a, h):
26 | if h.num_gpus > 1:
27 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
28 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
29 |
30 | torch.cuda.manual_seed(h.seed)
31 | device = torch.device('cuda:{:d}'.format(rank))
32 |
33 | generator = Generator(h).to(device)
34 | mpd = MultiPeriodDiscriminator().to(device)
35 | msd = MultiScaleDiscriminator().to(device)
36 |
37 | if rank == 0:
38 | print(generator)
39 | os.makedirs(a.checkpoint_path, exist_ok=True)
40 | print("checkpoints directory : ", a.checkpoint_path)
41 |
42 | if os.path.isdir(a.checkpoint_path):
43 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
44 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
45 |
46 | steps = 0
47 | if cp_g is None or cp_do is None:
48 | state_dict_do = None
49 | last_epoch = -1
50 | else:
51 | state_dict_g = load_checkpoint(cp_g, device)
52 | state_dict_do = load_checkpoint(cp_do, device)
53 | generator.load_state_dict(state_dict_g['generator'])
54 | mpd.load_state_dict(state_dict_do['mpd'])
55 | msd.load_state_dict(state_dict_do['msd'])
56 | steps = state_dict_do['steps'] + 1
57 | last_epoch = state_dict_do['epoch']
58 |
59 | if h.num_gpus > 1:
60 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
61 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
62 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
63 |
64 | optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
65 | optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
66 | h.learning_rate, betas=[h.adam_b1, h.adam_b2])
67 |
68 | if state_dict_do is not None:
69 | optim_g.load_state_dict(state_dict_do['optim_g'])
70 | optim_d.load_state_dict(state_dict_do['optim_d'])
71 |
72 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
73 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
74 |
75 | training_filelist, validation_filelist = get_dataset_filelist(a)
76 |
77 | trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
78 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
79 | shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
80 | fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
81 |
82 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
83 |
84 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
85 | sampler=train_sampler,
86 | batch_size=h.batch_size,
87 | pin_memory=True,
88 | drop_last=True)
89 |
90 | if rank == 0:
91 | validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
92 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
93 | fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
94 | base_mels_path=a.input_mels_dir)
95 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
96 | sampler=None,
97 | batch_size=1,
98 | pin_memory=True,
99 | drop_last=True)
100 |
101 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
102 |
103 | generator.train()
104 | mpd.train()
105 | msd.train()
106 | for epoch in range(max(0, last_epoch), a.training_epochs):
107 | if rank == 0:
108 | start = time.time()
109 | print("Epoch: {}".format(epoch + 1))
110 |
111 | if h.num_gpus > 1:
112 | train_sampler.set_epoch(epoch)
113 |
114 | for i, batch in enumerate(train_loader):
115 | if rank == 0:
116 | start_b = time.time()
117 | x, y, _, y_mel = batch
118 | x = torch.autograd.Variable(x.to(device, non_blocking=True))
119 | y = torch.autograd.Variable(y.to(device, non_blocking=True))
120 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
121 | y = y.unsqueeze(1)
122 | # print(y.shape)
123 | # y.shape: [16, 1, 8192]
124 | y_g_hat = generator(x)
125 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
126 | h.win_size,
127 | h.fmin, h.fmax_for_loss)
128 |
129 | optim_d.zero_grad()
130 |
131 | # MPD
132 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
133 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
134 |
135 | # MSD
136 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
137 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
138 |
139 | loss_disc_all = loss_disc_s + loss_disc_f
140 |
141 | loss_disc_all.backward()
142 | optim_d.step()
143 |
144 | # Generator
145 | optim_g.zero_grad()
146 |
147 | # L1 Mel-Spectrogram Loss
148 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
149 |
150 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
151 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
152 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
153 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
154 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
155 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
156 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
157 |
158 | loss_gen_all.backward()
159 | optim_g.step()
160 |
161 | if rank == 0:
162 | # STDOUT logging
163 | if steps % a.stdout_interval == 0:
164 | with torch.no_grad():
165 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
166 |
167 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
168 | format(steps, loss_gen_all, mel_error, time.time() - start_b))
169 |
170 | # checkpointing
171 | if steps % a.checkpoint_interval == 0 and steps != 0:
172 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
173 | save_checkpoint(checkpoint_path,
174 | {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
175 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
176 | save_checkpoint(checkpoint_path,
177 | {'mpd': (mpd.module if h.num_gpus > 1
178 | else mpd).state_dict(),
179 | 'msd': (msd.module if h.num_gpus > 1
180 | else msd).state_dict(),
181 | 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
182 | 'epoch': epoch})
183 |
184 | # Tensorboard summary logging
185 | if steps % a.summary_interval == 0:
186 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
187 | sw.add_scalar("training/mel_spec_error", mel_error, steps)
188 |
189 | # Validation
190 | if steps % a.validation_interval == 0: # and steps != 0:
191 | generator.eval()
192 | torch.cuda.empty_cache()
193 | val_err_tot = 0
194 | with torch.no_grad():
195 | for j, batch in enumerate(validation_loader):
196 | x, y, _, y_mel = batch
197 | y_g_hat = generator(x.to(device))
198 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
199 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
200 | h.hop_size, h.win_size,
201 | h.fmin, h.fmax_for_loss)
202 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
203 |
204 | if j <= 4:
205 | if steps == 0:
206 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
207 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
208 |
209 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
210 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
211 | h.sampling_rate, h.hop_size, h.win_size,
212 | h.fmin, h.fmax)
213 | sw.add_figure('generated/y_hat_spec_{}'.format(j),
214 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
215 |
216 | val_err = val_err_tot / (j + 1)
217 | sw.add_scalar("validation/mel_spec_error", val_err, steps)
218 |
219 | generator.train()
220 |
221 | steps += 1
222 |
223 | scheduler_g.step()
224 | scheduler_d.step()
225 |
226 | if rank == 0:
227 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
228 |
229 |
230 | def main():
231 | print('Initializing Training Process..')
232 |
233 | parser = argparse.ArgumentParser()
234 |
235 | parser.add_argument('--group_name', default=None)
236 | parser.add_argument('--input_wavs_dir', default='DUMMY1')
237 | parser.add_argument('--input_mels_dir', default='ft_dataset')
238 | parser.add_argument('--input_training_file',
239 | default='AISHELL-3/train.txt')
240 | parser.add_argument('--input_validation_file',
241 | default='AISHELL-3/val.txt')
242 | parser.add_argument('--checkpoint_path', default='cp_hifigan')
243 | parser.add_argument('--config', default='')
244 | parser.add_argument('--training_epochs', default=3100, type=int)
245 | parser.add_argument('--stdout_interval', default=5, type=int)
246 | parser.add_argument('--checkpoint_interval', default=10000, type=int)
247 | parser.add_argument('--summary_interval', default=100, type=int)
248 | parser.add_argument('--validation_interval', default=1000, type=int)
249 | parser.add_argument('--fine_tuning', default=False, type=bool)
250 |
251 | a = parser.parse_args()
252 |
253 | with open(a.config) as f:
254 | data = f.read()
255 |
256 | json_config = json.loads(data)
257 | h = AttrDict(json_config)
258 | build_env(a.config, 'config.json', a.checkpoint_path)
259 |
260 | torch.manual_seed(h.seed)
261 | if torch.cuda.is_available():
262 | torch.cuda.manual_seed(h.seed)
263 | h.num_gpus = torch.cuda.device_count()
264 | h.batch_size = int(h.batch_size / h.num_gpus)
265 | print('Batch size per GPU :', h.batch_size)
266 | else:
267 | pass
268 |
269 | if h.num_gpus > 1:
270 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
271 | else:
272 | train(0, a, h)
273 |
274 |
275 | if __name__ == '__main__':
276 | main()
277 |
--------------------------------------------------------------------------------
/train_ada_hifi_gan.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.simplefilter(action='ignore', category=FutureWarning)
4 | import itertools
5 | import os
6 | import time
7 | import argparse
8 | import json
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DistributedSampler, DataLoader
13 | import torch.multiprocessing as mp
14 | from torch.distributed import init_process_group
15 | from torch.nn.parallel import DistributedDataParallel
16 | from env import AttrDict, build_env
17 | from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
18 | from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss, \
19 | discriminator_loss, Generator_S, Generator_T, ada_loss
20 | from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
21 |
22 | torch.backends.cudnn.benchmark = True
23 |
24 |
25 | def train(rank, a, h):
26 | if h.num_gpus > 1:
27 | init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
28 | world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
29 |
30 | torch.cuda.manual_seed(h.seed)
31 | device = torch.device('cuda:{:d}'.format(rank))
32 |
33 | generator_s = Generator_S(h).to(device)
34 | generator_t = Generator_T(h).to(device)
35 |
36 | mpd = MultiPeriodDiscriminator().to(device)
37 | msd = MultiScaleDiscriminator().to(device)
38 |
39 | if rank == 0:
40 | print(generator_t)
41 | os.makedirs(a.checkpoint_path, exist_ok=True)
42 | print("checkpoints directory : ", a.checkpoint_path)
43 |
44 | if os.path.isdir(a.checkpoint_path):
45 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
46 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
47 |
48 | steps = 0
49 | if cp_g is None or cp_do is None:
50 | state_dict_do = None
51 | last_epoch = -1
52 | else:
53 | state_dict_g = load_checkpoint(cp_g, device)
54 | state_dict_do = load_checkpoint(cp_do, device)
55 | generator_s.load_state_dict(state_dict_g['generator'])
56 | generator_t.load_state_dict(state_dict_g['generator'])
57 | mpd.load_state_dict(state_dict_do['mpd'])
58 | msd.load_state_dict(state_dict_do['msd'])
59 | steps = state_dict_do['steps'] + 1
60 | last_epoch = state_dict_do['epoch']
61 |
62 | if h.num_gpus > 1:
63 | generator_s = DistributedDataParallel(generator_s, device_ids=[rank]).to(device)
64 | generator_t = DistributedDataParallel(generator_t, device_ids=[rank]).to(device)
65 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
66 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
67 |
68 | optim_g_t = torch.optim.AdamW(generator_t.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
69 | optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
70 | h.learning_rate, betas=[h.adam_b1, h.adam_b2])
71 |
72 | if state_dict_do is not None:
73 | optim_g_t.load_state_dict(state_dict_do['optim_g'])
74 | optim_d.load_state_dict(state_dict_do['optim_d'])
75 |
76 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g_t, gamma=h.lr_decay, last_epoch=last_epoch)
77 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
78 |
79 | training_filelist, validation_filelist = get_dataset_filelist(a)
80 |
81 | trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
82 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
83 | shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
84 | fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
85 |
86 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
87 |
88 | train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
89 | sampler=train_sampler,
90 | batch_size=h.batch_size,
91 | pin_memory=True,
92 | drop_last=True)
93 |
94 | if rank == 0:
95 | validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
96 | h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
97 | fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
98 | base_mels_path=a.input_mels_dir)
99 | validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
100 | sampler=None,
101 | batch_size=1,
102 | pin_memory=True,
103 | drop_last=True)
104 |
105 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
106 |
107 | generator_t.train()
108 | mpd.train()
109 | msd.train()
110 | for epoch in range(max(0, last_epoch), a.training_epochs):
111 | if rank == 0:
112 | start = time.time()
113 | print("Epoch: {}".format(epoch + 1))
114 |
115 | if h.num_gpus > 1:
116 | train_sampler.set_epoch(epoch)
117 |
118 | for i, batch in enumerate(train_loader):
119 | if rank == 0:
120 | start_b = time.time()
121 | x, y, _, y_mel = batch
122 | x = torch.autograd.Variable(x.to(device, non_blocking=True))
123 | y = torch.autograd.Variable(y.to(device, non_blocking=True))
124 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
125 | y = y.unsqueeze(1)
126 | _, fmap_s = generator_s(x)
127 | y_g_hat, fmap_t = generator_t(x)
128 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
129 | h.win_size,
130 | h.fmin, h.fmax_for_loss)
131 |
132 | optim_d.zero_grad()
133 |
134 | # MPD
135 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
136 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
137 |
138 | # MSD
139 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
140 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
141 |
142 | loss_disc_all = loss_disc_s + loss_disc_f
143 |
144 | loss_disc_all.backward()
145 | optim_d.step()
146 |
147 | # Generator
148 | optim_g_t.zero_grad()
149 |
150 | # ada loss
151 | a_loss, _ = ada_loss(fmap_s, fmap_t)
152 | a_loss = a_loss * 1000
153 |
154 | # L1 Mel-Spectrogram Loss
155 | loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
156 |
157 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
158 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
159 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
160 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
161 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
162 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
163 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + a_loss
164 |
165 | loss_gen_all.backward()
166 | optim_g_t.step()
167 |
168 | if rank == 0:
169 | # STDOUT logging
170 | if steps % a.stdout_interval == 0:
171 | with torch.no_grad():
172 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
173 |
174 | print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
175 | format(steps, loss_gen_all, mel_error, time.time() - start_b))
176 |
177 | # checkpointing
178 | if steps % a.checkpoint_interval == 0 and steps != 0:
179 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
180 | save_checkpoint(checkpoint_path,
181 | {'generator': (generator_t.module if h.num_gpus > 1 else generator_t).state_dict()})
182 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
183 | save_checkpoint(checkpoint_path,
184 | {'mpd': (mpd.module if h.num_gpus > 1
185 | else mpd).state_dict(),
186 | 'msd': (msd.module if h.num_gpus > 1
187 | else msd).state_dict(),
188 | 'optim_g': optim_g_t.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
189 | 'epoch': epoch})
190 |
191 | # Tensorboard summary logging
192 | if steps % a.summary_interval == 0:
193 | sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
194 | sw.add_scalar("training/mel_spec_error", mel_error, steps)
195 | sw.add_scalar("training/ada_losses", a_loss, steps)
196 |
197 | # Validation
198 | if steps % a.validation_interval == 0: # and steps != 0:
199 | generator_t.eval()
200 | torch.cuda.empty_cache()
201 | val_err_tot = 0
202 | with torch.no_grad():
203 | for j, batch in enumerate(validation_loader):
204 | x, y, _, y_mel = batch
205 | y_g_hat, _ = generator_t(x.to(device))
206 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
207 | y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
208 | h.hop_size, h.win_size,
209 | h.fmin, h.fmax_for_loss)
210 | val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
211 |
212 | if j <= 4:
213 | # if steps == 0:
214 | # sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
215 | # sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
216 | sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
217 | sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
218 |
219 | sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
220 | y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
221 | h.sampling_rate, h.hop_size, h.win_size,
222 | h.fmin, h.fmax)
223 | sw.add_figure('generated/y_hat_spec_{}'.format(j),
224 | plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
225 |
226 | val_err = val_err_tot / (j + 1)
227 | sw.add_scalar("validation/mel_spec_error", val_err, steps)
228 |
229 | generator_t.train()
230 |
231 | steps += 1
232 |
233 | scheduler_g.step()
234 | scheduler_d.step()
235 |
236 | if rank == 0:
237 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
238 |
239 |
240 | def main():
241 | print('Initializing Training Process..')
242 |
243 | parser = argparse.ArgumentParser()
244 |
245 | parser.add_argument('--group_name', default=None)
246 | parser.add_argument('--input_wavs_dir', default='DataBaker')
247 | parser.add_argument('--input_mels_dir', default='ft_dataset')
248 | parser.add_argument('--input_training_file',
249 | default='DataBaker/train.txt')
250 | parser.add_argument('--input_validation_file',
251 | default='DataBaker/val.txt')
252 | parser.add_argument('--checkpoint_path', default='cp_ada_hifigan')
253 | parser.add_argument('--config', default='')
254 | parser.add_argument('--training_epochs', default=31000, type=int)
255 | parser.add_argument('--stdout_interval', default=5, type=int)
256 | parser.add_argument('--checkpoint_interval', default=1000, type=int)
257 | parser.add_argument('--summary_interval', default=10, type=int)
258 | parser.add_argument('--validation_interval', default=100, type=int)
259 | parser.add_argument('--fine_tuning', default=False, type=bool)
260 |
261 | a = parser.parse_args()
262 |
263 | with open(a.config) as f:
264 | data = f.read()
265 |
266 | json_config = json.loads(data)
267 | h = AttrDict(json_config)
268 | build_env(a.config, 'config.json', a.checkpoint_path)
269 |
270 | torch.manual_seed(h.seed)
271 | if torch.cuda.is_available():
272 | torch.cuda.manual_seed(h.seed)
273 | h.num_gpus = torch.cuda.device_count()
274 | h.batch_size = int(h.batch_size / h.num_gpus)
275 | print('Batch size per GPU :', h.batch_size)
276 | else:
277 | pass
278 |
279 | if h.num_gpus > 1:
280 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
281 | else:
282 | train(0, a, h)
283 |
284 |
285 | if __name__ == '__main__':
286 | main()
287 |
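288 | # The cross-domain consistency ("ada") loss itself is implemented in models.py as
289 | # ada_loss(fmap_s, fmap_t) and is not shown in this dump. Purely as an illustration
290 | # (an assumption, not the repository's implementation), one common formulation
291 | # compares the pairwise similarity structure of the source and target generators'
292 | # feature maps within a batch:
293 | #
294 | # def ada_loss_sketch(fmap_s, fmap_t):
295 | #     loss = 0.0
296 | #     for fs, ft in zip(fmap_s, fmap_t):  # per-layer feature maps
297 | #         b = fs.size(0)
298 | #         fs, ft = fs.reshape(b, -1), ft.reshape(b, -1)
299 | #         sim_s = F.softmax(F.cosine_similarity(fs.unsqueeze(0), fs.unsqueeze(1), dim=-1), dim=-1)
300 | #         sim_t = F.log_softmax(F.cosine_similarity(ft.unsqueeze(0), ft.unsqueeze(1), dim=-1), dim=-1)
301 | #         loss = loss + F.kl_div(sim_t, sim_s, reduction='batchmean')
302 | #     return loss, None  # second return value unused at the call site above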
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6 | from utils import init_weights, get_padding
7 |
8 | LRELU_SLOPE = 0.1
9 |
10 |
11 | class ResBlock1(torch.nn.Module):
12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
13 | super(ResBlock1, self).__init__()
14 | self.h = h
15 | self.convs1 = nn.ModuleList([
16 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
17 | padding=get_padding(kernel_size, dilation[0]))),
18 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
19 | padding=get_padding(kernel_size, dilation[1]))),
20 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
21 | padding=get_padding(kernel_size, dilation[2])))
22 | ])
23 | self.convs1.apply(init_weights)
24 |
25 | self.convs2 = nn.ModuleList([
26 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
27 | padding=get_padding(kernel_size, 1))),
28 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
29 | padding=get_padding(kernel_size, 1))),
30 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
31 | padding=get_padding(kernel_size, 1)))
32 | ])
33 | self.convs2.apply(init_weights)
34 |
35 | def forward(self, x):
36 | for c1, c2 in zip(self.convs1, self.convs2):
37 | xt = F.leaky_relu(x, LRELU_SLOPE)
38 | xt = c1(xt)
39 | xt = F.leaky_relu(xt, LRELU_SLOPE)
40 | xt = c2(xt)
41 | x = xt + x
42 | return x
43 |
44 | def remove_weight_norm(self):
45 | for l in self.convs1:
46 | remove_weight_norm(l)
47 | for l in self.convs2:
48 | remove_weight_norm(l)
49 |
50 |
51 | class ResBlock2(torch.nn.Module):
52 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
53 | super(ResBlock2, self).__init__()
54 | self.h = h
55 | self.convs = nn.ModuleList([
56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
57 | padding=get_padding(kernel_size, dilation[0]))),
58 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
59 | padding=get_padding(kernel_size, dilation[1])))
60 | ])
61 | self.convs.apply(init_weights)
62 |
63 | def forward(self, x):
64 | for c in self.convs:
65 | xt = F.leaky_relu(x, LRELU_SLOPE)
66 | xt = c(xt)
67 | x = xt + x
68 | return x
69 |
70 | def remove_weight_norm(self):
71 | for l in self.convs:
72 | remove_weight_norm(l)
73 |
74 |
75 |
76 | class Generator(torch.nn.Module):
77 | def __init__(self, h):
78 | super(Generator, self).__init__()
79 | self.h = h # config_v1_json
80 | self.num_kernels = len(h.resblock_kernel_sizes) # len([3,7,11])
81 | self.num_upsamples = len(h.upsample_rates) # len([8,8,2,2])
82 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) # upsample_initial_channel = 512
83 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 # resblock = 1
84 |
85 | self.ups = nn.ModuleList()
86 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): # upsample_kernel_sizes = [16,16,4,4]
87 | self.ups.append(weight_norm(
88 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
89 | k, u, padding=(k-u)//2)))
90 | self.resblocks = nn.ModuleList()
91 | for i in range(len(self.ups)):
92 | ch = h.upsample_initial_channel//(2**(i+1))
93 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
94 | self.resblocks.append(resblock(h, ch, k, d))
95 |
96 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
97 | self.ups.apply(init_weights)
98 | self.conv_post.apply(init_weights)
99 |
100 | def forward(self, x):
101 | x = self.conv_pre(x)
102 | for i in range(self.num_upsamples):
103 | x = F.leaky_relu(x, LRELU_SLOPE)
104 | x = self.ups[i](x)
105 | xs = None
106 | for j in range(self.num_kernels):
107 | if xs is None:
108 | xs = self.resblocks[i*self.num_kernels+j](x)
109 | else:
110 | xs += self.resblocks[i*self.num_kernels+j](x)
111 | x = xs / self.num_kernels
112 | x = F.leaky_relu(x)
113 | x = self.conv_post(x)
114 | x = torch.tanh(x)
115 | return x
116 |
117 | def remove_weight_norm(self):
118 | print('Removing weight norm...')
119 | for l in self.ups:
120 | remove_weight_norm(l)
121 | for l in self.resblocks:
122 | l.remove_weight_norm()
123 | remove_weight_norm(self.conv_pre)
124 | remove_weight_norm(self.conv_post)
125 |
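# A minimal usage sketch for Generator (illustrative; assumes the default config_v1.json values
# referenced in the comments above: 80 mel bands, upsample_rates [8, 8, 2, 2]):
#   h = AttrDict(json.loads(open('config_v1.json').read()))
#   g = Generator(h)
#   mel = torch.randn(1, 80, 100)   # (batch, n_mels, frames)
#   wav = g(mel)                    # (batch, 1, frames * 8*8*2*2) -> (1, 1, 25600)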
126 |
127 | class Generator_S(torch.nn.Module):
128 | def __init__(self, h):
129 | super(Generator_S, self).__init__()
130 | self.h = h # config_v1_json
131 | self.num_kernels = len(h.resblock_kernel_sizes) # len([3,7,11])
132 | self.num_upsamples = len(h.upsample_rates) # len([8,8,2,2])
133 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) # upsample_initial_channel = 512
134 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 # resblock = 1
135 |
136 | self.ups = nn.ModuleList()
137 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): # upsample_kernel_sizes = [16,16,4,4]
138 | self.ups.append(weight_norm(
139 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
140 | k, u, padding=(k-u)//2)))
141 | self.resblocks = nn.ModuleList()
142 | for i in range(len(self.ups)):
143 | ch = h.upsample_initial_channel//(2**(i+1))
144 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
145 | self.resblocks.append(resblock(h, ch, k, d))
146 |
147 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
148 | self.ups.apply(init_weights)
149 | self.conv_post.apply(init_weights)
150 |         for p in self.parameters():
151 |             p.requires_grad = False  # Generator_S is frozen: it only supplies reference feature maps
152 |
153 | def forward(self, x):
154 |         fmap = []  # intermediate feature maps, returned for the adaptation loss (see ada_loss below)
155 | x = self.conv_pre(x)
156 | fmap.append(x)
157 | for i in range(self.num_upsamples):
158 | x = F.leaky_relu(x, LRELU_SLOPE)
159 | x = self.ups[i](x)
160 | xs = None
161 | for j in range(self.num_kernels):
162 | if xs is None:
163 | xs = self.resblocks[i*self.num_kernels+j](x)
164 | else:
165 | xs += self.resblocks[i*self.num_kernels+j](x)
166 | x = xs / self.num_kernels
167 | fmap.append(x)
168 | x = F.leaky_relu(x)
169 | x = self.conv_post(x)
170 | fmap.append(x)
171 | x = torch.tanh(x)
172 | return x, fmap
173 |
174 | def remove_weight_norm(self):
175 | print('Removing weight norm...')
176 | for l in self.ups:
177 | remove_weight_norm(l)
178 | for l in self.resblocks:
179 | l.remove_weight_norm()
180 | remove_weight_norm(self.conv_pre)
181 | remove_weight_norm(self.conv_post)
182 |
183 |
184 | class Generator_T(torch.nn.Module):
185 | def __init__(self, h):
186 | super(Generator_T, self).__init__()
187 | self.h = h # config_v1_json
188 | self.num_kernels = len(h.resblock_kernel_sizes) # len([3,7,11])
189 | self.num_upsamples = len(h.upsample_rates) # len([8,8,2,2])
190 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) # upsample_initial_channel = 512
191 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 # resblock = 1
192 |
193 | self.ups = nn.ModuleList()
194 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): # upsample_kernel_sizes = [16,16,4,4]
195 | self.ups.append(weight_norm(
196 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
197 | k, u, padding=(k-u)//2)))
198 | self.resblocks = nn.ModuleList()
199 | for i in range(len(self.ups)):
200 | ch = h.upsample_initial_channel//(2**(i+1))
201 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
202 | self.resblocks.append(resblock(h, ch, k, d))
203 |
204 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
205 | self.ups.apply(init_weights)
206 | self.conv_post.apply(init_weights)
207 |
208 |
209 | def forward(self, x):
210 |         fmap = []  # intermediate feature maps, returned for the adaptation loss (see ada_loss below)
211 | x = self.conv_pre(x)
212 | fmap.append(x)
213 | for i in range(self.num_upsamples):
214 | x = F.leaky_relu(x, LRELU_SLOPE)
215 | x = self.ups[i](x)
216 | xs = None
217 | for j in range(self.num_kernels):
218 | if xs is None:
219 | xs = self.resblocks[i*self.num_kernels+j](x)
220 | else:
221 | xs += self.resblocks[i*self.num_kernels+j](x)
222 | x = xs / self.num_kernels
223 | fmap.append(x)
224 | x = F.leaky_relu(x)
225 | x = self.conv_post(x)
226 | fmap.append(x)
227 | x = torch.tanh(x)
228 | return x, fmap
229 |
230 | def remove_weight_norm(self):
231 | print('Removing weight norm...')
232 | for l in self.ups:
233 | remove_weight_norm(l)
234 | for l in self.resblocks:
235 | l.remove_weight_norm()
236 | remove_weight_norm(self.conv_pre)
237 | remove_weight_norm(self.conv_post)
238 |
239 |
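# DiscriminatorP folds the 1-D waveform into a 2-D tensor of shape (T/period, period), so the
# (kernel_size, 1) convolutions below slide over samples that are `period` steps apart.
# Rough shape trace (illustrative, period=2, T=8):
#   x: (B, 1, 8) --view--> (B, 1, 4, 2) --convs + conv_post--> (B, 1, T', 2) --flatten--> (B, T'*2)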
240 | class DiscriminatorP(torch.nn.Module):
241 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
242 | super(DiscriminatorP, self).__init__()
243 | self.period = period
244 |         norm_f = spectral_norm if use_spectral_norm else weight_norm
245 | self.convs = nn.ModuleList([
246 |             norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),  # padding here assumes the default kernel_size of 5
247 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
248 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
249 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
250 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
251 | ])
252 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
253 |
254 | def forward(self, x):
255 | fmap = []
256 | b, c, t = x.shape
257 |
258 | if t % self.period != 0: # pad first
259 | n_pad = self.period - (t % self.period)
260 | x = F.pad(x, (0, n_pad), "reflect")
261 | t = t + n_pad
262 | x = x.view(b, c, t // self.period, self.period)
263 |
264 | for l in self.convs:
265 | x = l(x)
266 | x = F.leaky_relu(x, LRELU_SLOPE)
267 | fmap.append(x)
268 | x = self.conv_post(x)
269 | fmap.append(x)
270 | x = torch.flatten(x, 1, -1)
271 | return x, fmap
272 |
273 |
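# MultiPeriodDiscriminator runs one DiscriminatorP per period in [2, 3, 5, 7, 11] on both the real
# and the generated waveform, returning per-discriminator scores and intermediate feature maps
# (consumed by feature_loss below). The prime periods follow the HiFi-GAN design choice of
# minimizing overlap between the periodic views.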
274 | class MultiPeriodDiscriminator(torch.nn.Module):
275 | def __init__(self):
276 | super(MultiPeriodDiscriminator, self).__init__()
277 | self.discriminators = nn.ModuleList([
278 | DiscriminatorP(2),
279 | DiscriminatorP(3),
280 | DiscriminatorP(5),
281 | DiscriminatorP(7),
282 | DiscriminatorP(11),
283 | ])
284 |
285 | def forward(self, y, y_hat):
286 | y_d_rs = []
287 | y_d_gs = []
288 | fmap_rs = []
289 | fmap_gs = []
290 | for i, d in enumerate(self.discriminators):
291 | y_d_r, fmap_r = d(y)
292 | y_d_g, fmap_g = d(y_hat)
293 | y_d_rs.append(y_d_r)
294 | fmap_rs.append(fmap_r)
295 | y_d_gs.append(y_d_g)
296 | fmap_gs.append(fmap_g)
297 |
298 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
299 |
300 |
301 | class DiscriminatorS(torch.nn.Module):
302 | def __init__(self, use_spectral_norm=False):
303 | super(DiscriminatorS, self).__init__()
304 |         norm_f = spectral_norm if use_spectral_norm else weight_norm
305 |
306 | self.convs = nn.ModuleList([
307 | norm_f(Conv1d(1, 128, 15, 1, padding=7)),
308 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
309 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
310 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
311 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
312 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
313 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
314 | ])
315 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
316 |
317 | def forward(self, x):
318 | fmap = []
319 | for l in self.convs:
320 | x = l(x)
321 | x = F.leaky_relu(x, LRELU_SLOPE)
322 | fmap.append(x)
323 | x = self.conv_post(x)
324 | fmap.append(x)
325 | x = torch.flatten(x, 1, -1)
326 | return x, fmap
327 |
328 |
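# MultiScaleDiscriminator applies three DiscriminatorS copies at three time scales: the raw
# waveform, then the waveform average-pooled 2x and 4x (the meanpools are applied cumulatively in
# forward()). Only the raw-scale discriminator uses spectral norm; the pooled scales use weight norm.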
329 | class MultiScaleDiscriminator(torch.nn.Module):
330 | def __init__(self):
331 | super(MultiScaleDiscriminator, self).__init__()
332 | self.discriminators = nn.ModuleList([
333 | DiscriminatorS(use_spectral_norm=True),
334 | DiscriminatorS(),
335 | DiscriminatorS(),
336 | ])
337 | self.meanpools = nn.ModuleList([
338 | AvgPool1d(4, 2, padding=2),
339 | AvgPool1d(4, 2, padding=2)
340 | ])
341 |
342 | def forward(self, y, y_hat):
343 | y_d_rs = []
344 | y_d_gs = []
345 | fmap_rs = []
346 | fmap_gs = []
347 | for i, d in enumerate(self.discriminators):
348 | if i != 0:
349 | y = self.meanpools[i-1](y)
350 | y_hat = self.meanpools[i-1](y_hat)
351 | y_d_r, fmap_r = d(y)
352 | y_d_g, fmap_g = d(y_hat)
353 | y_d_rs.append(y_d_r)
354 | fmap_rs.append(fmap_r)
355 | y_d_gs.append(y_d_g)
356 | fmap_gs.append(fmap_g)
357 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
358 |
359 |
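# feature_loss: L1 feature-matching term. For every discriminator and every intermediate layer it
# accumulates mean(|f_real - f_generated|), and the total is scaled by 2 (the weighting used here).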
360 | def feature_loss(fmap_r, fmap_g):
361 | loss = 0
362 | for dr, dg in zip(fmap_r, fmap_g):
363 | for rl, gl in zip(dr, dg):
364 | loss += torch.mean(torch.abs(rl - gl))
365 |
366 | return loss*2
367 |
368 |
369 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
370 |     '''
371 |     LS-GAN
372 |     Xudong Mao, Qing Li, Haoran Xie, Raymond Y. K. Lau, Zhen Wang, and Stephen Paul Smolley.
373 |     Least squares generative adversarial networks.
374 |     In Proceedings of the IEEE International Conference on Computer Vision, pages 2794-2802, 2017.
375 |
376 |     The discriminator is trained to classify ground-truth samples as 1 and samples synthesized by
377 |     the generator as 0. The generator is trained to fool the discriminator by improving sample
378 |     quality until its outputs are classified as values close to 1.
379 |     '''
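    # For reference, the least-squares objectives implemented below are (per sub-discriminator D):
    #   L_D = E[(1 - D(x))^2] + E[D(G(s))^2]    # x: real audio, G(s): generated audio
    #   L_G = E[(1 - D(G(s)))^2]                # computed in generator_loss() further down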
380 |
381 | loss = 0
382 | r_losses = []
383 | g_losses = []
384 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
385 | r_loss = torch.mean((1-dr)**2)
386 | g_loss = torch.mean(dg**2)
387 | loss += (r_loss + g_loss)
388 | r_losses.append(r_loss.item())
389 | g_losses.append(g_loss.item())
390 |
391 | return loss, r_losses, g_losses
392 |
393 |
394 | def generator_loss(disc_outputs):
395 | loss = 0
396 | gen_losses = []
397 | for dg in disc_outputs:
398 | l = torch.mean((1-dg)**2)
399 | gen_losses.append(l)
400 | loss += l
401 |
402 | return loss, gen_losses
403 |
404 |
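# ada_loss: adaptation loss between two lists of feature maps (e.g. from Generator_S and
# Generator_T). Each pair of maps is flattened to (batch, -1); batch-wise cosine-similarity
# matrices cs_s and cs_t are built; and for every row the softmax over the off-diagonal
# similarities of the two matrices is compared with a KL divergence, averaged over the batch.
# Rough shape trace (illustrative, batch=4): s, t: (4, C, T) -> (4, C*T); cs_s, cs_t: (4, 4);
# each per-row distribution has length 3 after the diagonal entry is removed.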
405 | def ada_loss(fmap_s, fmap_t):
406 | loss = 0
407 | ada_losses = []
408 | for s, t in zip(fmap_s, fmap_t):
409 | l_losses = 0
410 | s = s.reshape(s.shape[0], -1)
411 | t = t.reshape(t.shape[0], -1)
412 |         cs_s = torch.matmul(s, s.T) / torch.matmul(torch.linalg.norm(s, dim=1).reshape(-1, 1),
413 |                                                    torch.linalg.norm(s, dim=1).reshape(1, -1))
414 |         cs_t = torch.matmul(t, t.T) / torch.matmul(torch.linalg.norm(t, dim=1).reshape(-1, 1),
415 |                                                    torch.linalg.norm(t, dim=1).reshape(1, -1))
416 | # calculate kl-dist
417 | for i, (cs_s_i, cs_t_i) in enumerate(zip(cs_s, cs_t)):
418 | cs_s_i = torch.softmax(torch.cat((cs_s_i[:i], cs_s_i[i+1:]), 0), 0)
419 | cs_t_i = torch.softmax(torch.cat((cs_t_i[:i], cs_t_i[i+1:]), 0), 0)
420 | l_losses += F.kl_div(cs_t_i.log(), cs_s_i, reduction='mean')
421 | l_losses /= cs_s.shape[0]
422 | loss += l_losses
423 | ada_losses.append(l_losses)
424 |
425 | return loss, ada_losses
426 |
427 |
--------------------------------------------------------------------------------