├── hps
│   ├── v4.json
│   ├── v7.json
│   ├── v8.json
│   └── ae.json
├── README.md
├── preprocess
│   ├── make_samples.py
│   ├── make_batches.py
│   └── make_dataset.py
├── main.py
├── test.py
├── utils.py
├── model.py
└── solver.py

--------------------------------------------------------------------------------
/hps/v4.json:
--------------------------------------------------------------------------------
{
    "lr": 0.0002,
    "alpha": 0.0001,
    "beta": 0.01,
    "lambda_": 10,
    "ns": 0.01,
    "dp": 0.3,
    "max_grad_norm": 5,
    "max_step": 5,
    "seg_len": 128,
    "D_iterations": 4,
    "batch_size": 32,
    "pretrain_iterations": 20000,
    "iterations": 100000
}
--------------------------------------------------------------------------------
/hps/v7.json:
--------------------------------------------------------------------------------
{
    "lr": 0.0002,
    "alpha": 0.0001,
    "beta": 0.001,
    "lambda_": 10,
    "ns": 0.01,
    "dp": 0.3,
    "max_grad_norm": 5,
    "max_step": 5,
    "seg_len": 128,
    "D_iterations": 4,
    "batch_size": 32,
    "scheduled_iterations": 20000,
    "iterations": 100000
}
--------------------------------------------------------------------------------
/hps/v8.json:
--------------------------------------------------------------------------------
{
    "lr": 0.0002,
    "alpha": 0.0001,
    "beta": 0.0001,
    "lambda_": 10,
    "ns": 0.01,
    "dp": 0.0,
    "max_grad_norm": 5,
    "max_step": 5,
    "seg_len": 128,
    "D_iterations": 4,
    "batch_size": 32,
    "scheduled_iterations": 300000,
    "iterations": 500000
}
--------------------------------------------------------------------------------
/hps/ae.json:
--------------------------------------------------------------------------------
{
    "lr": 0.0002,
    "alpha": 0.0,
    "beta": 0.0,
    "lambda_": 10,
    "ns": 0.01,
    "dp": 0.0,
    "max_grad_norm": 5,
    "max_step": 5,
    "seg_len": 128,
    "n_latent_steps": 0,
    "n_patch_steps": 0,
    "batch_size": 32,
    "scheduled_iterations": 20000,
    "iterations": 100000
}
--------------------------------------------------------------------------------
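Note: these checked-in configs predate the current Hps schema in utils.py — keys
like "alpha", "beta", "D_iterations", and "pretrain_iterations"/"scheduled_iterations"
do not match the namedtuple fields ("alpha_dis", "alpha_enc", ..., "lat_sched_iters",
"iters"), so Hps.load() raises a TypeError on them. A minimal tolerant loader, as a
sketch — the key filtering and default-keeping here are an assumption, not repo
behavior:

import json
from utils import Hps

def load_hps_tolerant(path):
    # keep defaults for fields the file omits, drop keys Hps no longer knows
    hps = Hps()
    with open(path) as f:
        d = json.load(f)
    known = {k: v for k, v in d.items() if k in hps.get_tuple()._fields}
    hps._hps = hps.get_tuple()._replace(**known)
    return hps

--------------------------------------------------------------------------------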
/README.md:
--------------------------------------------------------------------------------
# voice conversion via disentangled context representations

## TODO
- Calculate similarity in latent space
- Patch GAN with condition
- Patch GAN with classification
- Patch GAN with randomly flipped labels?
- i-vector extraction
- Patch GAN with i-vector prediction
- add classifier to embedding
- add postnet and mel-scale spectrograms
- accent conversion
- instance norm
--------------------------------------------------------------------------------
/preprocess/make_samples.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')
from utils import Sampler
import h5py
import numpy as np
import json

max_step = 5
seg_len = 32  # NOTE: keep in sync with the training seg_len (the shipped configs use 128)
mel_band = 80
lin_band = 513
n_samples = 2000000

if __name__ == '__main__':
    if len(sys.argv) < 3:
        print('usage: python3 make_samples.py [in_h5py_path] [out_json_path]')
        exit(0)
    sampler = Sampler(sys.argv[1], max_step=max_step, seg_len=seg_len)
    samples = [sampler.sample()._asdict() for _ in range(n_samples)]
    with open(sys.argv[2], 'w') as f_json:
        json.dump(samples, f_json, indent=4, separators=(',', ': '))
--------------------------------------------------------------------------------
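Each dumped sample is a utils.Sampler index (speaker ids, utterance keys, frame
offsets) rather than audio. A minimal sketch of reading one back — mirroring what
utils.myDataset does at train time; the paths follow the defaults in main.py and
the example values are illustrative only:

import json
import h5py

seg_len = 128  # must match the seg_len the indices were generated with
with open('/storage/raw_feature/voice_conversion/vctk/128_513_2000k.json') as f:
    index = json.load(f)[0]
# e.g. {'speaker_i': 3, 'speaker_j': 5, 'i0': 'p225/003', 'i1': 'p225/017',
#       'j': 'p226/008', 't': 241, 't_k': 450, 't_prime': 88, 't_j': 102}
with h5py.File('/storage/raw_feature/voice_conversion/vctk/vctk.h5', 'r') as h5:
    x_i_t = h5['train/{}/lin'.format(index['i0'])][index['t']:index['t'] + seg_len]

--------------------------------------------------------------------------------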
/main.py:
--------------------------------------------------------------------------------
import torch
from torch import optim
from torch.autograd import Variable
import numpy as np
import pickle
from utils import Hps
from utils import DataLoader
from utils import Logger
from utils import myDataset
from solver import Solver
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # train by default; --test turns training off (the old '--train' flag used
    # action='store_true' with default=True, so it could never be disabled)
    parser.add_argument('--train', dest='train', default=True, action='store_true')
    parser.add_argument('--test', dest='train', action='store_false')
    parser.add_argument('--load_model', default=False, action='store_true')
    parser.add_argument('-flag', default='train')
    parser.add_argument('-hps_path', default='./hps/v7.json')
    parser.add_argument('-load_model_path', default='/storage/model/voice_conversion/'
                        'pretrain_model.pkl-19999')
    parser.add_argument('-dataset_path', default='/storage/raw_feature/voice_conversion/vctk/vctk.h5')
    parser.add_argument('-index_path', default='/storage/raw_feature/voice_conversion/vctk/128_513_2000k.json')
    parser.add_argument('-output_model_path', default='/storage/model/voice_conversion/model.pkl')
    args = parser.parse_args()
    hps = Hps()
    hps.load(args.hps_path)
    hps_tuple = hps.get_tuple()
    dataset = myDataset(args.dataset_path,
                        args.index_path,
                        seg_len=hps_tuple.seg_len)
    data_loader = DataLoader(dataset)

    solver = Solver(hps_tuple, data_loader)
    if args.load_model:
        solver.load_model(args.load_model_path)
    if args.train:
        solver.train(args.output_model_path, args.flag)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from torch import optim
from torch.autograd import Variable
import numpy as np
import pickle
from utils import Hps
from utils import DataLoader
from utils import Logger
from utils import myDataset
from solver import Solver
from preprocess.tacotron.utils import spectrogram2wav
from scipy.io.wavfile import write

if __name__ == '__main__':
    hps = Hps()
    hps.load('./hps/v15.json')
    hps_tuple = hps.get_tuple()
    dataset = myDataset('/storage/raw_feature/voice_conversion/vctk/vctk.h5',
                        '/storage/raw_feature/voice_conversion/vctk/128_513_2000k.json')
    data_loader = DataLoader(dataset)
    solver = Solver(hps_tuple, data_loader)
    solver.load_model('/storage/model/voice_conversion/model_v15_res.pkl')
    spec = np.loadtxt('./preprocess/test_code/lin.npy')
    spec2 = np.loadtxt('./preprocess/test_code/lin2.npy')
    spec_expand = np.expand_dims(spec, axis=0)
    spec_tensor = torch.from_numpy(spec_expand)
    spec_tensor = spec_tensor.type(torch.FloatTensor)
    spec2_expand = np.expand_dims(spec2, axis=0)
    spec2_tensor = torch.from_numpy(spec2_expand)
    spec2_tensor = spec2_tensor.type(torch.FloatTensor)
    c1 = Variable(torch.from_numpy(np.array([0]))).cuda()
    c2 = Variable(torch.from_numpy(np.array([4]))).cuda()
    # convert both utterances to both target speakers (c1 and c2)
    result1 = solver.test_step(spec_tensor, c1)
    result1 = result1.squeeze(axis=0).transpose((1, 0))
    result2 = solver.test_step(spec2_tensor, c2)
    result2 = result2.squeeze(axis=0).transpose((1, 0))
    result3 = solver.test_step(spec2_tensor, c1)
    result3 = result3.squeeze(axis=0).transpose((1, 0))
    result4 = solver.test_step(spec_tensor, c2)
    result4 = result4.squeeze(axis=0).transpose((1, 0))
    # output.wav/output2.wav are the vocoded originals; output3-6.wav are conversions
    wav_data = spectrogram2wav(spec)
    write('output.wav', rate=16000, data=wav_data)
    wav_data = spectrogram2wav(spec2)
    write('output2.wav', rate=16000, data=wav_data)
    wav_data = spectrogram2wav(result1)
    write('output3.wav', rate=16000, data=wav_data)
    wav_data = spectrogram2wav(result2)
    write('output4.wav', rate=16000, data=wav_data)
    wav_data = spectrogram2wav(result3)
    write('output5.wav', rate=16000, data=wav_data)
    wav_data = spectrogram2wav(result4)
    write('output6.wav', rate=16000, data=wav_data)
--------------------------------------------------------------------------------
/preprocess/make_batches.py:
--------------------------------------------------------------------------------
# NOTE: legacy script. It predates the index-based pipeline (make_samples.py +
# utils.myDataset) and assumes an older Sampler whose sample() returned raw
# (mel, lin) segment arrays rather than the current index namedtuple.
import sys
sys.path.append('../')
from utils import Sampler
import h5py
import numpy as np

max_step = 5
seg_len = 128
mel_band = 80
lin_band = 1025
batch_size = 16
n_batches = 100000

if __name__ == '__main__':
    if len(sys.argv) < 3:
        print('usage: python3 make_batches.py [in_h5py_path] [out_h5py_path]')
        exit(0)
    sampler = Sampler(sys.argv[1], max_step=max_step, seg_len=seg_len)
    with h5py.File(sys.argv[2], 'w') as f_h5:
        for i in range(n_batches):
            samples = {
                'X_i_t': {
                    'mel': np.empty(shape=(batch_size, seg_len, mel_band), dtype=np.float32),
                    #'lin': np.empty(shape=(batch_size, seg_len, lin_band), dtype=np.float32)
                },
                'X_i_tk': {
                    'mel': np.empty(shape=(batch_size, seg_len, mel_band), dtype=np.float32),
                    #'lin': np.empty(shape=(batch_size, seg_len, lin_band), dtype=np.float32)
                },
                'X_i_tk_prime': {
                    'mel': np.empty(shape=(batch_size, seg_len, mel_band), dtype=np.float32),
                    #'lin': np.empty(shape=(batch_size, seg_len, lin_band), dtype=np.float32)
                },
                'X_j': {
                    'mel': np.empty(shape=(batch_size, seg_len, mel_band), dtype=np.float32),
                    #'lin': np.empty(shape=(batch_size, seg_len, lin_band), dtype=np.float32)
                },
            }
            for j in range(batch_size):
                sample = sampler.sample()
                samples['X_i_t']['mel'][j, :] = sample[0]
                #samples['X_i_t']['lin'][j, :] = sample[1]
                samples['X_i_tk']['mel'][j, :] = sample[2]
                #samples['X_i_tk']['lin'][j, :] = sample[3]
                samples['X_i_tk_prime']['mel'][j, :] = sample[4]
                #samples['X_i_tk_prime']['lin'][j, :] = sample[5]
                samples['X_j']['mel'][j, :] = sample[6]
                #samples['X_j']['lin'][j, :] = sample[7]

            for data_name in samples:
                for data_type in samples[data_name]:
                    data = samples[data_name][data_type]
                    f_h5.create_dataset(
                        '{}/{}/{}'.format(i, data_name, data_type),
                        data=data,
                        dtype=np.float32,
                    )
            if i % 5 == 0:
                print('process [{}/{}]'.format(i, n_batches))
--------------------------------------------------------------------------------
/preprocess/make_dataset.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np
import sys
import os
import glob
import re
from collections import defaultdict
from tacotron.utils import get_spectrograms

'''DEPRECATED
def sort_key(x):
    sub = x.split('/')[-1]
    l = re.split('_|-', sub.strip('.npy'))
    if len(l[-1]) == 1:
        l[-1] = '0{}'.format(l[-1])
    return ''.join(l)

if __name__ == '__main__':
    if len(sys.argv) < 3:
        print('usage: python3 make_dataset.py [h5py_path] [numpy_dir]')
        exit(0)
    h5py_path = sys.argv[1]
    np_dir = sys.argv[2]

    with h5py.File(h5py_path, 'w') as f_h5:
        for dataset, dir_name in zip(['train', 'test'], ['train-clean-100', 'train-clean-100_infer']):
            utt = []
            # tuple: (speaker, chapter, utt)
            prev_utt = None
            np_sub_dir = os.path.join(np_dir, dir_name)
            for i, filename in enumerate(sorted(glob.glob(os.path.join(np_sub_dir, '*')), key=sort_key)):
                speaker_id, chapter_id, other = filename.split('/')[-1].split('-')
                utt_id, seg_id = other[:-4].split('_')
                current_utt = (speaker_id, chapter_id, utt_id)
                d = np.load(filename)
                utt.append(d)
                if prev_utt != current_utt and i != 0:
                    print('dump {}'.format(prev_utt))
                    data = np.array(utt, dtype=np.float32)
                    print(data.shape)
                    grp = f_h5.create_dataset(
                        '{}/{}/{}-{}'.format(dataset, speaker_id, chapter_id, utt_id),
                        data=data,
                        dtype=np.float32,
                    )
                    utt = []
                prev_utt = current_utt
            # last utt
            if len(utt) > 0:
                print('dump {}'.format(prev_utt))
                data = np.array(utt, dtype=np.float32)
                print(data.shape)
                grp = f_h5.create_dataset(
                    '{}/{}/{}-{}'.format(dataset, speaker_id, chapter_id, utt_id),
                    data=data,
                    dtype=np.float32,
                )
'''

root_dir = '/storage/LibriSpeech/LibriSpeech/train-clean-100'

def parse_flac_name(filename):
    # the old code used str.strip('.flac'), which strips a *character set* from
    # both ends rather than removing the suffix; drop the extension instead
    base = os.path.splitext(os.path.basename(filename))[0]
    speaker_id, chapter_id, segment_id = base.split('-')
    return speaker_id, chapter_id, segment_id

if __name__ == '__main__':
    if len(sys.argv) < 2:
        print('usage: python3 make_dataset.py [h5py_path]')
        exit(0)
    h5py_path = sys.argv[1]
    filename_groups = defaultdict(lambda: [])
    with h5py.File(h5py_path, 'w') as f_h5:
        grps = [f_h5.create_group('train'), f_h5.create_group('test')]
        filenames = sorted(glob.glob(os.path.join(root_dir, '*/*/*.flac')))
        for filename in filenames:
            # divide into groups by speaker
            speaker_id, chapter_id, segment_id = parse_flac_name(filename)
            filename_groups[speaker_id].append(filename)
        for speaker_id, filenames in filename_groups.items():
            print('processing {}'.format(speaker_id))
            for filename in filenames[:-1]:
                print(filename)
                speaker_id, chapter_id, segment_id = parse_flac_name(filename)
                mel_spec, lin_spec = get_spectrograms(filename)
                grps[0].create_dataset('{}/{}-{}/mel'.format(speaker_id, chapter_id, segment_id),
                                       data=mel_spec, dtype=np.float32)
                grps[0].create_dataset('{}/{}-{}/lin'.format(speaker_id, chapter_id, segment_id),
                                       data=lin_spec, dtype=np.float32)
            # the last segment of each speaker goes into the test set
            filename = filenames[-1]
            speaker_id, chapter_id, segment_id = parse_flac_name(filename)
            mel_spec, lin_spec = get_spectrograms(filename)
            grps[1].create_dataset('{}/{}-{}/mel'.format(speaker_id, chapter_id, segment_id),
                                   data=mel_spec, dtype=np.float32)
            grps[1].create_dataset('{}/{}-{}/lin'.format(speaker_id, chapter_id, segment_id),
                                   data=lin_spec, dtype=np.float32)
--------------------------------------------------------------------------------
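The resulting HDF5 layout is {train,test}/{speaker}/{chapter}-{segment}/{mel,lin}.
A minimal sketch that walks the file and prints array shapes (the path is a
placeholder):

import h5py

with h5py.File('vctk.h5', 'r') as f:
    for split in ('train', 'test'):
        for speaker in f[split]:
            for utt in f['{}/{}'.format(split, speaker)]:
                grp = f['{}/{}/{}'.format(split, speaker, utt)]
                print(split, speaker, utt, grp['mel'].shape, grp['lin'].shape)

--------------------------------------------------------------------------------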
/utils.py:
--------------------------------------------------------------------------------
import json
import h5py
import pickle
import os
from collections import defaultdict
from collections import namedtuple
import numpy as np
import math
import argparse
import random
import time
import torch
from torch.utils import data
from tensorboardX import SummaryWriter
from torch.autograd import Variable

class Hps(object):
    def __init__(self):
        self.hps = namedtuple('hps', [
            'lr',
            'alpha_dis',
            'alpha_enc',
            'beta_dis',
            'beta_dec',
            'beta_clf',
            'lambda_',
            'ns',
            'dp',
            'max_grad_norm',
            'max_step',
            'seg_len',
            'emb_size',
            'n_latent_steps',
            'n_patch_steps',
            'batch_size',
            'lat_sched_iters',
            'patch_start_iter',
            'iters',
            ]
        )
        default = \
            [1e-4, 1e-2, 1e-4, 1e-3, 1e-4, 1e-4, 10, 0.01, 0.0, 5, 5, 128, 128, 5, 5, 32, 50000, 50000, 60000]
        self._hps = self.hps._make(default)

    def get_tuple(self):
        return self._hps

    def load(self, path):
        # the JSON keys must match the namedtuple fields above exactly
        with open(path, 'r') as f_json:
            hps_dict = json.load(f_json)
        self._hps = self.hps(**hps_dict)

    def dump(self, path):
        with open(path, 'w') as f_json:
            json.dump(self._hps._asdict(), f_json, indent=4, separators=(',', ': '))

class Sampler(object):
    def __init__(
        self,
        h5_path='/storage/raw_feature/voice_conversion/vctk/vctk.h5',
        speaker_info_path='/storage/raw_feature/voice_conversion/vctk/speaker-info.txt',
        utt_len_path='/storage/raw_feature/voice_conversion/vctk/vctk_length.txt',
        max_step=5,
        seg_len=64,
        n_speaker=8,
    ):
        self.f_h5 = h5py.File(h5_path, 'r')
        self.max_step = max_step
        self.seg_len = seg_len
        #self.read_sex_file(speaker_sex_path)
        self.read_vctk_speaker_file(speaker_info_path)
        self.utt2len = self.read_utt_len_file(utt_len_path)
        self.speakers = list(self.f_h5['train'].keys())
        self.n_speaker = n_speaker
        self.speaker_used = self.female_ids[:n_speaker // 2] + self.male_ids[:n_speaker // 2]
        self.speaker2utts = {speaker: list(self.f_h5['train/{}'.format(speaker)].keys())
                             for speaker in self.speakers}
        # remove utterances too short to slice segments from
        self.rm_too_short_utt()
        self.indexer = namedtuple('index', ['speaker_i', 'speaker_j',
                                            'i0', 'i1', 'j', 't', 't_k', 't_prime', 't_j'])

    def read_utt_len_file(self, utt_len_path):
        with open(utt_len_path, 'r') as f:
            # skip the header; each line is: speaker, utt, length
            f.readline()
            lines = [tuple(line.strip().split()) for line in f.readlines()]
        mapping = {(speaker, utt_id): int(length) for speaker, utt_id, length in lines}
        return mapping

    def rm_too_short_utt(self, limit=None):
        if not limit:
            limit = self.seg_len * 2
        for (speaker, utt_id), length in self.utt2len.items():
            if length < limit:
                self.speaker2utts[speaker].remove(utt_id)

    def read_vctk_speaker_file(self, speaker_info_path):
        self.female_ids, self.male_ids = [], []
        with open(speaker_info_path, 'r') as f:
            lines = f.readlines()
            infos = [line.strip().split() for line in lines[1:]]
            for info in infos:
                if info[2] == 'F':
                    self.female_ids.append(info[0])
                else:
                    self.male_ids.append(info[0])

    def read_libre_sex_file(self, speaker_sex_path):
        with open(speaker_sex_path, 'r') as f:
            # Female
            f.readline()
            self.female_ids = f.readline().strip().split()
            # Male
            f.readline()
            self.male_ids = f.readline().strip().split()

    def sample_utt(self, speaker_id, n_samples=1):
        # sample utterances from one speaker
        utt_ids = random.sample(self.speaker2utts[speaker_id], n_samples)
        lengths = [self.f_h5[f'train/{speaker_id}/{utt_id}/mel'].shape[0] for utt_id in utt_ids]
        return [(utt_id, length) for utt_id, length in zip(utt_ids, lengths)]

    def rand(self, l):
        rand_idx = random.randint(0, len(l) - 1)
        return l[rand_idx]

    def sample(self):
        seg_len = self.seg_len
        max_step = self.max_step
        # sample two speakers
        speakerA_idx, speakerB_idx = random.sample(range(len(self.speaker_used)), 2)
        speakerA, speakerB = self.speaker_used[speakerA_idx], self.speaker_used[speakerB_idx]
        (A_utt_id_0, A_len_0), (A_utt_id_1, A_len_1) = self.sample_utt(speakerA, 2)
        (B_utt_id, B_len), = self.sample_utt(speakerB, 1)
        # sample t and t_k from the same utterance, at most max_step segments apart
        t = random.randint(0, A_len_0 - 2 * seg_len)
        t_k = random.randint(t + seg_len, min(A_len_0 - seg_len, t + max_step * seg_len))
        t_prime = random.randint(0, A_len_1 - seg_len)
        # sample a segment from speakerB
        t_j = random.randint(0, B_len - seg_len)
        index_tuple = self.indexer(speaker_i=speakerA_idx, speaker_j=speakerB_idx,
                                   i0=f'{speakerA}/{A_utt_id_0}', i1=f'{speakerA}/{A_utt_id_1}',
                                   j=f'{speakerB}/{B_utt_id}', t=t, t_k=t_k, t_prime=t_prime, t_j=t_j)
        return index_tuple

#class DataLoader(object):
#    def __init__(self, h5py_path, batch_size=16):
#        self.f_h5 = h5py.File(h5py_path)
#        self.keys = list(self.f_h5.keys())
#        self.index = 0
#        self.batch_size = batch_size
#
#    def __iter__(self):
#        return self
#
#    def __next__(self):
#        if self.index >= len(self.keys):
#            self.index = 0
#        key = self.keys[self.index]
#        data = (self.f_h5['{}/X_i_t/mel'.format(key)][0:self.batch_size],
#                self.f_h5['{}/X_i_tk/mel'.format(key)][0:self.batch_size],
#                self.f_h5['{}/X_i_tk_prime/mel'.format(key)][0:self.batch_size],
#                self.f_h5['{}/X_j/mel'.format(key)][0:self.batch_size])
#        self.index += 1
#        return data

class DataLoader(object):
    def __init__(self, dataset, batch_size=16):
        self.dataset = dataset
        self.n_elements = len(self.dataset[0])
        self.batch_size = batch_size
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        samples = [self.dataset[self.index + i] for i in range(self.batch_size)]
        # transpose: a list of per-sample tuples -> one list per element
        batch = [list(sample) for sample in zip(*samples)]
        batch_tensor = [torch.from_numpy(np.array(data)) for data in batch]

        if self.index + 2 * self.batch_size >= len(self.dataset):
            self.index = 0
        else:
            self.index += self.batch_size
        return tuple(batch_tensor)

class myDataset(data.Dataset):
    def __init__(self, h5_path, index_path, seg_len=64):
        self.h5 = h5py.File(h5_path, 'r')
        with open(index_path) as f_index:
            self.indexes = json.load(f_index)
        self.indexer = namedtuple('index', ['speaker_i', 'speaker_j',
                                            'i0', 'i1', 'j', 't', 't_k', 't_prime', 't_j'])
        self.seg_len = seg_len

    def __getitem__(self, i):
        index = self.indexes[i]
        index = self.indexer(**index)
        speaker_i, speaker_j = index.speaker_i, index.speaker_j
        i0, i1, j = index.i0, index.i1, index.j
        t, t_k, t_prime, t_j = index.t, index.t_k, index.t_prime, index.t_j
        seg_len = self.seg_len
        data = [speaker_i, speaker_j]
        data.append(self.h5[f'train/{i0}/lin'][t:t+seg_len])
        data.append(self.h5[f'train/{i0}/lin'][t_k:t_k+seg_len])
        data.append(self.h5[f'train/{i1}/lin'][t_prime:t_prime+seg_len])
        data.append(self.h5[f'train/{j}/lin'][t_j:t_j+seg_len])
        return tuple(data)

    def __len__(self):
        return len(self.indexes)

class Logger(object):
    def __init__(self, log_dir='./log'):
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)

if __name__ == '__main__':
    hps = Hps()
    hps.dump('./hps/v18.json')
    dataset = myDataset('/home_local/jjery2243542/voice_conversion/datasets/vctk/vctk.h5',
                        '/home_local/jjery2243542/voice_conversion/datasets/vctk/128_513_2000k.json')
    data_loader = DataLoader(dataset)
    for i, batch in enumerate(data_loader):
        print(torch.max(batch[2]))
    #sampler = Sampler()
    #for i in range(100):
    #    print(sampler.sample())
--------------------------------------------------------------------------------
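One batch from this DataLoader is a 6-tuple in the order produced by
myDataset.__getitem__. A minimal sketch (paths are placeholders):

from utils import myDataset, DataLoader

dataset = myDataset('vctk.h5', '128_513_2000k.json', seg_len=128)
loader = DataLoader(dataset, batch_size=16)
batch = next(loader)
# batch[0], batch[1]: speaker_i, speaker_j       -- LongTensor, shape (16,)
# batch[2] .. batch[5]: x_i_t, x_i_tk, x_i_tk_prime, x_j
#                                                -- FloatTensor, shape (16, 128, 513)
# Solver.permute_data later moves the spectrograms to (batch, 513, seg_len).

--------------------------------------------------------------------------------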
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable

def pad_layer(inp, layer, is_2d=False):
    # reflect-pad so that the layer behaves like a 'same' convolution
    if type(layer.kernel_size) == tuple:
        kernel_size = layer.kernel_size[0]
    else:
        kernel_size = layer.kernel_size
    if not is_2d:
        if kernel_size % 2 == 0:
            pad = (kernel_size//2, kernel_size//2 - 1)
        else:
            pad = (kernel_size//2, kernel_size//2)
    else:
        if kernel_size % 2 == 0:
            pad = (kernel_size//2, kernel_size//2 - 1, kernel_size//2, kernel_size//2 - 1)
        else:
            pad = (kernel_size//2, kernel_size//2, kernel_size//2, kernel_size//2)
    inp = F.pad(inp,
                pad=pad,
                mode='reflect')
    out = layer(inp)
    return out

def upsample(x, scale_factor=2):
    # the scale_factor argument was previously ignored (hard-coded to 2)
    x_up = F.upsample(x, scale_factor=scale_factor, mode='nearest')
    return x_up

def GLU(inp, layer, res=True):
    kernel_size = layer.kernel_size[0]
    channels = layer.out_channels // 2
    # padding
    out = F.pad(inp.unsqueeze(dim=3), pad=(0, 0, kernel_size//2, kernel_size//2), mode='constant', value=0.)
    out = out.squeeze(dim=3)
    out = layer(out)
    # gated activation
    A = out[:, :channels, :]
    B = F.sigmoid(out[:, channels:, :])
    if res:
        H = A * B + inp
    else:
        H = A * B
    return H

def highway(inp, layers, gates, act):
    # permute to (batch, time, channel)
    batch_size = inp.size(0)
    seq_len = inp.size(2)
    inp_permuted = inp.permute(0, 2, 1)
    # merge the batch and time dimensions
    out_expand = inp_permuted.contiguous().view(batch_size*seq_len, inp_permuted.size(2))
    for l, g in zip(layers, gates):
        H = l(out_expand)
        H = act(H)
        T = g(out_expand)
        T = F.sigmoid(T)
        out_expand = H * T + out_expand * (1. - T)
    out_permuted = out_expand.view(batch_size, seq_len, out_expand.size(1))
    out = out_permuted.permute(0, 2, 1)
    return out

def RNN(inp, layer):
    inp_permuted = inp.permute(2, 0, 1)
    state_mul = (int(layer.bidirectional) + 1) * layer.num_layers
    zero_state = Variable(torch.zeros(state_mul, inp.size(0), layer.hidden_size))
    zero_state = zero_state.cuda() if torch.cuda.is_available() else zero_state
    out_permuted, _ = layer(inp_permuted, zero_state)
    out_rnn = out_permuted.permute(1, 2, 0)
    return out_rnn

def linear(inp, layer):
    batch_size = inp.size(0)
    hidden_dim = inp.size(1)
    seq_len = inp.size(2)
    inp_permuted = inp.permute(0, 2, 1)
    inp_expand = inp_permuted.contiguous().view(batch_size*seq_len, hidden_dim)
    out_expand = layer(inp_expand)
    out_permuted = out_expand.view(batch_size, seq_len, out_expand.size(1))
    out = out_permuted.permute(0, 2, 1)
    return out

def append_emb(inp, layer, expand_size, output):
    # look up the speaker embedding and tile it along the time axis
    emb = layer(inp)
    emb = emb.unsqueeze(dim=2)
    emb_expand = emb.expand(emb.size(0), emb.size(1), expand_size)
    output = torch.cat([output, emb_expand], dim=1)
    return output

class PatchDiscriminator(nn.Module):
    def __init__(self, c_in=513, n_class=8, ns=0.2, dp=0.3):
        super(PatchDiscriminator, self).__init__()
        self.ns = ns
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=5, stride=2)
        self.conv5 = nn.Conv2d(512, 1, kernel_size=8)
        self.conv_classify = nn.Conv2d(512, n_class, kernel_size=(33, 8))
        self.drop1 = nn.Dropout(p=dp)
        self.drop2 = nn.Dropout(p=dp)
        self.drop3 = nn.Dropout(p=dp)
        self.drop4 = nn.Dropout(p=dp)

    def forward(self, x, classify=False):
        x = torch.unsqueeze(x, dim=1)
        out = pad_layer(x, self.conv1, is_2d=True)
        out = self.drop1(out)
        out = F.leaky_relu(out, negative_slope=self.ns)
        out = pad_layer(out, self.conv2, is_2d=True)
        out = self.drop2(out)
        out = F.leaky_relu(out, negative_slope=self.ns)
        out = pad_layer(out, self.conv3, is_2d=True)
        out = self.drop3(out)
        out = F.leaky_relu(out, negative_slope=self.ns)
        out = pad_layer(out, self.conv4, is_2d=True)
        out = self.drop4(out)
        out = F.leaky_relu(out, negative_slope=self.ns)
        # GAN output value
        val = pad_layer(out, self.conv5, is_2d=True)
        val = val.view(val.size(0), -1)
        mean_val = torch.mean(val, dim=1)
        if classify:
            logits = self.conv_classify(out)
            logits = logits.view(logits.size()[0], -1)
            logits = F.log_softmax(logits, dim=1)
            return mean_val, logits
        else:
            return mean_val

class LatentDiscriminator(nn.Module):
    def __init__(self, c_in=1024, c_h=256, ns=0.2, dp=0.3):
        super(LatentDiscriminator, self).__init__()
        self.ns = ns
        self.conv1 = nn.Conv1d(c_in, c_h, kernel_size=5)
        self.conv2 = nn.Conv1d(c_h, c_h, kernel_size=5)
        self.conv3 = nn.Conv1d(c_h, c_h, kernel_size=5)
        self.conv4 = nn.Conv1d(c_h, c_h, kernel_size=5)
        self.conv5 = nn.Conv1d(c_h, 1, kernel_size=16)
        self.drop1 = nn.Dropout(p=dp)
        self.drop2 = nn.Dropout(p=dp)
        self.drop3 = nn.Dropout(p=dp)
        self.drop4 = nn.Dropout(p=dp)

    def forward(self, x):
        out1 = pad_layer(x, self.conv1)
        out1 = self.drop1(out1)
        out1 = F.leaky_relu(out1, negative_slope=self.ns)
        out2 = pad_layer(out1, self.conv2)
        out2 = self.drop2(out2)
        out2 = F.leaky_relu(out2, negative_slope=self.ns)
        # residual: out1 has c_h channels and matches out2; the raw input x has
        # c_in channels, so the original `out2 + x` could not broadcast
        out2 = out2 + out1
        out3 = pad_layer(out2, self.conv3)
        out3 = self.drop3(out3)
        out3 = F.leaky_relu(out3, negative_slope=self.ns)
        out4 = pad_layer(out3, self.conv4)
        out4 = self.drop4(out4)
        out4 = F.leaky_relu(out4, negative_slope=self.ns)
        out = out4 + out2
        out = self.conv5(out)
        out = out.view(out.size()[0], -1)
        mean_value = torch.mean(out, dim=1)
        return mean_value

class CBHG(nn.Module):
    def __init__(self, c_in=80, c_out=513):
        super(CBHG, self).__init__()
        self.conv1s = nn.ModuleList(
            [nn.Conv1d(c_in, 128, kernel_size=k) for k in range(1, 9)]
        )
        self.bn1s = nn.ModuleList([nn.BatchNorm1d(128) for _ in range(1, 9)])
        self.mp1 = nn.MaxPool1d(kernel_size=2, stride=1)
        self.conv2 = nn.Conv1d(len(self.conv1s)*128, 256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.conv3 = nn.Conv1d(256, 80, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(80)
        # highway network
        self.linear1 = nn.Linear(80, 128)
        self.layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(4)])
        self.gates = nn.ModuleList([nn.Linear(128, 128) for _ in range(4)])
        self.RNN = nn.GRU(input_size=128, hidden_size=128, num_layers=1, bidirectional=True)
        self.linear2 = nn.Linear(256, c_out)

    def forward(self, x):
        outs = []
        for l in self.conv1s:
            out = pad_layer(x, l)
            out = F.relu(out)
            outs.append(out)
        bn_outs = []
        for out, bn in zip(outs, self.bn1s):
            out = bn(out)
            bn_outs.append(out)
        out = torch.cat(bn_outs, dim=1)
        out = pad_layer(out, self.mp1)
        out = self.conv2(out)
        out = F.relu(out)
        out = self.bn2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = out + x
        out = linear(out, self.linear1)
        out = highway(out, self.layers, self.gates, F.relu)
        out_rnn = RNN(out, self.RNN)
        out = linear(out_rnn, self.linear2)
        out = F.sigmoid(out)
        return out

class Decoder(nn.Module):
    def __init__(self, c_in=512, c_out=513, c_h=512, c_a=8, emb_size=128, ns=0.2):
        super(Decoder, self).__init__()
        self.ns = ns
        self.conv1 = nn.Conv1d(c_in + emb_size, c_h, kernel_size=5)
        self.conv2 = nn.Conv1d(c_h + emb_size, c_h, kernel_size=5)
        self.conv3 = nn.Conv1d(c_h + emb_size, c_h, kernel_size=5)
        self.conv4 = nn.Conv1d(c_h + emb_size, c_h, kernel_size=5)
        self.conv5 = nn.Conv1d(c_h + emb_size, c_h, kernel_size=5)
        self.dense1 = nn.Linear(c_h, c_h)
        self.dense2 = nn.Linear(c_h, c_h)
        self.RNN = nn.GRU(input_size=c_h + emb_size, hidden_size=c_h//2, num_layers=1, bidirectional=True)
        self.emb = nn.Embedding(c_a, emb_size)
        self.linear = nn.Linear(2*c_h + emb_size, c_out)

    def forward(self, x, c):
        # conv layers; the first three are preceded by 2x upsampling, undoing
        # the encoder's three stride-2 convolutions
        inp = append_emb(c, self.emb, x.size(2), x)
        inp = upsample(inp)
        out1 = pad_layer(inp, self.conv1)
        out1 = F.leaky_relu(out1, negative_slope=self.ns)
        out2 = append_emb(c, self.emb, out1.size(2), out1)
        out2 = upsample(out2)
        out2 = pad_layer(out2, self.conv2)
        out2 = F.leaky_relu(out2, negative_slope=self.ns)
        out3 = append_emb(c, self.emb, out2.size(2), out2)
        out3 = upsample(out3)
        out3 = pad_layer(out3, self.conv3)
        out3 = F.leaky_relu(out3, negative_slope=self.ns)
        out4 = append_emb(c, self.emb, out3.size(2), out3)
        out4 = pad_layer(out4, self.conv4)
        out4 = F.leaky_relu(out4, negative_slope=self.ns)
        out5 = append_emb(c, self.emb, out4.size(2), out4)
        out5 = pad_layer(out5, self.conv5)
        out5 = F.leaky_relu(out5, negative_slope=self.ns)
        out = out5 + out3
        # dense layers
        out_dense1 = linear(out, self.dense1)
        out_dense1 = F.leaky_relu(out_dense1, negative_slope=self.ns)
        out_dense2 = linear(out_dense1, self.dense2)
        out_dense2 = F.leaky_relu(out_dense2, negative_slope=self.ns)
        out = out_dense2 + out
        out = append_emb(c, self.emb, out.size(2), out)
        # rnn layer
        out_rnn = RNN(out, self.RNN)
        out = torch.cat([out, out_rnn], dim=1)
        out = linear(out, self.linear)
        return out

class Encoder(nn.Module):
    def __init__(self, c_in=513, c_h1=128, c_h2=512, c_h3=256, ns=0.2):
        super(Encoder, self).__init__()
        self.ns = ns
        self.conv1s = nn.ModuleList(
            [nn.Conv1d(c_in, c_h1, kernel_size=k) for k in range(1, 16)]
        )
        self.conv2 = nn.Conv1d(len(self.conv1s)*c_h1 + c_in, c_h2, kernel_size=3)
        self.conv3 = nn.Conv1d(c_h2, c_h2, kernel_size=5, stride=2)
        self.conv4 = nn.Conv1d(c_h2, c_h2, kernel_size=5, stride=2)
        self.conv5 = nn.Conv1d(c_h2, c_h2, kernel_size=5, stride=2)
        self.dense1 = nn.Linear(c_h2, c_h2)
        self.dense2 = nn.Linear(c_h2, c_h2)
        self.dense3 = nn.Linear(c_h2, c_h2)
        self.dense4 = nn.Linear(c_h2, c_h2)
        self.RNN = nn.GRU(input_size=c_h2, hidden_size=c_h3, num_layers=2, bidirectional=True)
        self.linear = nn.Linear(c_h2 + 2*c_h3, c_h2)

    def forward(self, x):
        # multi-width conv bank over the input spectrogram
        outs = []
        for l in self.conv1s:
            out = pad_layer(x, l)
            outs.append(out)
        out = torch.cat(outs + [x], dim=1)
        out = F.leaky_relu(out, negative_slope=self.ns)
        out = pad_layer(out, self.conv2)
        out = F.leaky_relu(out, negative_slope=self.ns)
        # three stride-2 convs: the time axis is divided by 8
        out = pad_layer(out, self.conv3)
        out = F.leaky_relu(out, negative_slope=self.ns)
        out = pad_layer(out, self.conv4)
        out = F.leaky_relu(out, negative_slope=self.ns)
        out = pad_layer(out, self.conv5)
        out = F.leaky_relu(out, negative_slope=self.ns)
        out_dense1 = linear(out, self.dense1)
        out_dense1 = F.leaky_relu(out_dense1, negative_slope=self.ns)
        out_dense2 = linear(out_dense1, self.dense2)
        out_dense2 = F.leaky_relu(out_dense2, negative_slope=self.ns)
        out_dense2 = out_dense2 + out
        out_dense3 = linear(out_dense2, self.dense3)
        out_dense3 = F.leaky_relu(out_dense3, negative_slope=self.ns)
        out_dense4 = linear(out_dense3, self.dense4)
        out_dense4 = F.leaky_relu(out_dense4, negative_slope=self.ns)
        out = out_dense4 + out_dense2
        out_rnn = RNN(out, self.RNN)
        out = torch.cat([out, out_rnn], dim=1)
        out = linear(out, self.linear)
        return out

if __name__ == '__main__':
    # smoke test on random input
    E1, E2 = Encoder(513).cuda(), Encoder(513).cuda()
    D = Decoder().cuda()
    C = LatentDiscriminator().cuda()
    P = PatchDiscriminator().cuda()
    cbhg = CBHG().cuda()
    inp = Variable(torch.randn(16, 513, 128)).cuda()
    e1 = E1(inp)
    e2 = E2(inp)
    c = Variable(torch.from_numpy(np.random.randint(8, size=(16)))).cuda()
    d = D(e1, c)
    #print(d.size())
    p1, p2 = P(d, classify=True)
    #print(p1.size(), p2.size())
    c = C(torch.cat([e2, e2], dim=1))
    print(c.size())
--------------------------------------------------------------------------------
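A shape sanity check for the autoencoder path, mirroring the smoke test above
(it assumes a GPU, as model.py's own __main__ does): the Encoder's three
stride-2 convs divide the time axis by 8 and the Decoder's three 2x upsamples
restore it, so a (batch, 513, seg_len) input comes back at the same shape.

import numpy as np
import torch
from torch.autograd import Variable
from model import Encoder, Decoder

E, D = Encoder(513).cuda(), Decoder().cuda()
x = Variable(torch.randn(4, 513, 128)).cuda()    # (batch, lin bins, seg_len)
enc = E(x)
assert enc.size() == (4, 512, 16)                # time axis: 128 -> 64 -> 32 -> 16
c = Variable(torch.from_numpy(np.random.randint(8, size=(4,)))).cuda()
assert D(enc, c).size() == (4, 513, 128)         # three 2x upsamples restore seg_len

--------------------------------------------------------------------------------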
/solver.py:
--------------------------------------------------------------------------------
import torch
from torch import optim
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F
import numpy as np
import pickle
from utils import myDataset
from model import Encoder
from model import Decoder
from model import LatentDiscriminator
from model import PatchDiscriminator
from model import CBHG
import os
from utils import Hps
from utils import Logger
from utils import DataLoader
#from preprocess.tacotron import utils

def cal_mean_grad(net):
    grad = Variable(torch.FloatTensor([0])).cuda()
    for i, p in enumerate(net.parameters()):
        grad += torch.mean(p.grad)
    return grad.data[0] / (i + 1)

def calculate_gradients_penalty(netD, real_data, fake_data):
    alpha = torch.rand(real_data.size(0))
    alpha = alpha.view(real_data.size(0), 1, 1)
    alpha = alpha.cuda() if torch.cuda.is_available() else alpha
    alpha = Variable(alpha)
    interpolates = alpha * real_data + (1 - alpha) * fake_data

    disc_interpolates = netD(interpolates)

    gradients = torch.autograd.grad(
        outputs=torch.mean(disc_interpolates),
        inputs=interpolates,
        create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients_penalty = (1. - torch.sqrt(1e-12 + torch.sum(gradients.view(gradients.size(0), -1)**2, dim=1))) ** 2
    gradients_penalty = torch.mean(gradients_penalty)
    return gradients_penalty
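# The penalty above is the standard two-sided WGAN-GP term, evaluated on random
# interpolates x_hat = alpha * x_real + (1 - alpha) * x_fake with alpha ~ U[0, 1]:
#
#     GP = E[ (1 - || grad_{x_hat} D(x_hat) ||_2 )^2 ]
#
# The 1e-12 under the square root only keeps the gradient of the norm finite at
# zero; train() below scales the penalty by hps.lambda_ (10 in the shipped configs).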

class Solver(object):
    def __init__(self, hps, data_loader, log_dir='./log/'):
        self.hps = hps
        self.data_loader = data_loader
        self.model_kept = []
        self.max_keep = 10
        self.build_model()
        self.logger = Logger(log_dir)

    def build_model(self):
        ns = self.hps.ns
        emb_size = self.hps.emb_size
        self.Encoder = Encoder(ns=ns)
        self.Decoder = Decoder(ns=ns, emb_size=emb_size)
        self.LatentDiscriminator = LatentDiscriminator(ns=ns)
        self.PatchDiscriminator = PatchDiscriminator(ns=ns)
        if torch.cuda.is_available():
            self.Encoder.cuda()
            self.Decoder.cuda()
            self.LatentDiscriminator.cuda()
            self.PatchDiscriminator.cuda()
        betas = (0.5, 0.9)
        params = list(self.Encoder.parameters()) + list(self.Decoder.parameters())
        self.ae_opt = optim.Adam(params, lr=self.hps.lr, betas=betas)
        self.decoder_opt = optim.Adam(self.Decoder.parameters(), lr=self.hps.lr, betas=betas)
        self.lat_opt = optim.Adam(self.LatentDiscriminator.parameters(), lr=self.hps.lr, betas=betas)
        self.patch_opt = optim.Adam(self.PatchDiscriminator.parameters(), lr=self.hps.lr, betas=betas)

    def to_var(self, x, requires_grad=True):
        x = Variable(x, requires_grad=requires_grad)
        return x.cuda() if torch.cuda.is_available() else x

    def save_model(self, model_path, iteration, enc_only=True):
        if not enc_only:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
                'latent_discriminator': self.LatentDiscriminator.state_dict(),
                'patch_discriminator': self.PatchDiscriminator.state_dict(),
            }
        else:
            all_model = {
                'encoder': self.Encoder.state_dict(),
                'decoder': self.Decoder.state_dict(),
            }
        new_model_path = '{}-{}'.format(model_path, iteration)
        with open(new_model_path, 'wb') as f_out:
            torch.save(all_model, f_out)
        self.model_kept.append(new_model_path)

        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def reset_grad(self, net_list):
        for net in net_list:
            net.zero_grad()

    def load_model(self, model_path, enc_only=True):
        print('load model from {}'.format(model_path))
        with open(model_path, 'rb') as f_in:
            all_model = torch.load(f_in)
        self.Encoder.load_state_dict(all_model['encoder'])
        self.Decoder.load_state_dict(all_model['decoder'])
        if not enc_only:
            self.LatentDiscriminator.load_state_dict(all_model['latent_discriminator'])
            self.PatchDiscriminator.load_state_dict(all_model['patch_discriminator'])

    def grad_clip(self, net_list):
        max_grad_norm = self.hps.max_grad_norm
        for net in net_list:
            torch.nn.utils.clip_grad_norm(net.parameters(), max_grad_norm)

    def test_step(self, x, c):
        x = self.to_var(x).permute(0, 2, 1)
        enc = self.Encoder(x)
        x_tilde = self.Decoder(enc, c)
        return x_tilde.data.cpu().numpy()

    def permute_data(self, data):
        C = [self.to_var(c, requires_grad=False) for c in data[:2]]
        X = [self.to_var(x).permute(0, 2, 1) for x in data[2:]]
        return C, X

    def sample_c(self, size):
        c_sample = Variable(
            torch.multinomial(torch.ones(8), num_samples=size, replacement=True),
            requires_grad=False)
        c_sample = c_sample.cuda() if torch.cuda.is_available() else c_sample
        return c_sample

    def cal_acc(self, logits, y_true):
        _, ind = torch.max(logits, dim=1)
        acc = torch.sum((ind == y_true).type(torch.FloatTensor)) / y_true.size(0)
        return acc

    def encode_step(self, *args):
        enc_list = []
        for x in args:
            enc = self.Encoder(x)
            enc_list.append(enc)
        return tuple(enc_list)

    def decode_step(self, enc, c):
        x_tilde = self.Decoder(enc, c)
        return x_tilde

    def latent_discriminate_step(self, enc_i_t, enc_i_tk, enc_i_prime, enc_j, cal_gp=True):
        same_pair = torch.cat([enc_i_t, enc_i_tk], dim=1)
        diff_pair = torch.cat([enc_i_prime, enc_j], dim=1)
        same_val = self.LatentDiscriminator(same_pair)
        diff_val = self.LatentDiscriminator(diff_pair)
        w_dis = torch.mean(same_val - diff_val)
        if cal_gp:
            gp = calculate_gradients_penalty(self.LatentDiscriminator, same_pair, diff_pair)
            return w_dis, gp
        else:
            return (w_dis,)

    def patch_discriminate_step(self, x, x_tilde, cal_gp=True):
        # W-distance between real and generated spectrogram patches
        D_real, real_logits = self.PatchDiscriminator(x, classify=True)
        D_fake, fake_logits = self.PatchDiscriminator(x_tilde, classify=True)
        w_dis = torch.mean(D_real - D_fake)
        if cal_gp:
            gp = calculate_gradients_penalty(self.PatchDiscriminator, x, x_tilde)
            return w_dis, real_logits, fake_logits, gp
        else:
            return w_dis, real_logits, fake_logits
    # backup
    #def classify():
    #    # aux classify loss
    #    criterion = nn.NLLLoss()
    #    c_loss = criterion(real_logits, c) + criterion(fake_logits, c_sample)
    #    real_acc = self.cal_acc(real_logits, c)
    #    fake_acc = self.cal_acc(fake_logits, c_sample)
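
    # The two critics implement WGAN objectives:
    #   latent: w_dis = E[D_l(enc_t, enc_tk)] - E[D_l(enc_t', enc_j)]. The critic
    #     minimizes -alpha_dis * w_dis + lambda_ * gp, while the encoder minimizes
    #     +current_alpha * w_dis, pushing same-speaker segment pairs and
    #     different-speaker pairs toward indistinguishability in latent space.
    #   patch: w_dis = E[D_p(x)] - E[D_p(x_tilde)], plus an auxiliary speaker
    #     classifier (NLL over log-softmax logits) weighted by beta_clf.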

    def train(self, model_path, flag='train'):
        # load hyperparams
        hps = self.hps
        for iteration in range(hps.iters):
            # anneal alpha linearly over the first lat_sched_iters iterations
            # (the original code left current_alpha stale after the schedule ended)
            if iteration + 1 < hps.lat_sched_iters:
                current_alpha = hps.alpha_enc * (iteration + 1) / hps.lat_sched_iters
            else:
                current_alpha = hps.alpha_enc
            for step in range(hps.n_latent_steps):
                #===================== Train latent discriminator =====================#
                data = next(self.data_loader)
                (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data)
                # encode
                enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(x_i_t, x_i_tk, x_i_prime, x_j)
                # latent discriminate
                latent_w_dis, latent_gp = self.latent_discriminate_step(enc_i_t, enc_i_tk, enc_i_prime, enc_j)
                lat_loss = -hps.alpha_dis * latent_w_dis + hps.lambda_ * latent_gp
                self.reset_grad([self.LatentDiscriminator])
                lat_loss.backward()
                self.grad_clip([self.LatentDiscriminator])
                self.lat_opt.step()
                # print info
                info = {
                    f'{flag}/D_latent_w_dis': latent_w_dis.data[0],
                    f'{flag}/latent_gp': latent_gp.data[0],
                }
                slot_value = (step, iteration + 1, hps.iters) + \
                        tuple([value for value in info.values()])
                log = 'lat_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration)
            # two-stage training: the patch discriminator starts later
            if iteration >= hps.patch_start_iter:
                for step in range(hps.n_patch_steps):
                    #===================== Train patch discriminator =====================#
                    data = next(self.data_loader)
                    (c_i, _), (x_i_t, _, _, _) = self.permute_data(data)
                    # encode, then decode with a sampled target speaker
                    enc_i_t, = self.encode_step(x_i_t)
                    c_sample = self.sample_c(x_i_t.size(0))
                    x_tilde = self.decode_step(enc_i_t, c_sample)
                    patch_w_dis, real_logits, fake_logits, patch_gp = \
                            self.patch_discriminate_step(x_i_t, x_tilde)
                    # aux classification loss (the original call passed c_i/c_sample
                    # into patch_discriminate_step, which takes no such arguments;
                    # the commented-out classify() body is inlined here instead)
                    criterion = nn.NLLLoss()
                    c_loss = criterion(real_logits, c_i) + criterion(fake_logits, c_sample)
                    real_acc = self.cal_acc(real_logits, c_i)
                    fake_acc = self.cal_acc(fake_logits, c_sample)
                    patch_loss = -hps.beta_dis * patch_w_dis + hps.lambda_ * patch_gp + hps.beta_clf * c_loss
                    self.reset_grad([self.PatchDiscriminator])
                    patch_loss.backward()
                    self.grad_clip([self.PatchDiscriminator])
                    self.patch_opt.step()
                    # print info
                    info = {
                        f'{flag}/D_patch_w_dis': patch_w_dis.data[0],
                        f'{flag}/patch_gp': patch_gp.data[0],
                        f'{flag}/c_loss': c_loss.data[0],
                        f'{flag}/real_acc': real_acc,
                        f'{flag}/fake_acc': fake_acc,
                    }
                    slot_value = (step, iteration + 1, hps.iters) + \
                            tuple([value for value in info.values()])
                    log = 'patch_D-%d:[%06d/%06d], w_dis=%.3f, gp=%.2f, c_loss=%.3f, real_acc=%.2f, fake_acc=%.2f'
                    print(log % slot_value)
                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, iteration)
            #===================== Train G =====================#
            data = next(self.data_loader)
            (c_i, c_j), (x_i_t, x_i_tk, x_i_prime, x_j) = self.permute_data(data)
            # encode
            enc_i_t, enc_i_tk, enc_i_prime, enc_j = self.encode_step(x_i_t, x_i_tk, x_i_prime, x_j)
            # decode
            x_tilde = self.decode_step(enc_i_t, c_i)
            loss_rec = torch.mean(torch.abs(x_tilde - x_i_t))
            # latent discriminate
            latent_w_dis, = self.latent_discriminate_step(
                enc_i_t, enc_i_tk, enc_i_prime, enc_j, cal_gp=False)
            ae_loss = loss_rec + current_alpha * latent_w_dis
            self.reset_grad([self.Encoder, self.Decoder])
            retain_graph = hps.n_patch_steps > 0
            ae_loss.backward(retain_graph=retain_graph)
            self.grad_clip([self.Encoder, self.Decoder])
            self.ae_opt.step()
            info = {
                f'{flag}/loss_rec': loss_rec.data[0],
                f'{flag}/G_latent_w_dis': latent_w_dis.data[0],
                f'{flag}/alpha': current_alpha,
            }
            slot_value = (iteration+1, hps.iters) + tuple([value for value in info.values()])
            log = 'G:[%06d/%06d], loss_rec=%.2f, latent_w_dis=%.2f, alpha=%.2e'
            print(log % slot_value)
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, iteration + 1)
            # patch discriminate
            if hps.n_patch_steps > 0 and iteration >= hps.patch_start_iter:
                c_sample = self.sample_c(x_i_t.size(0))
                x_tilde = self.decode_step(enc_i_t, c_sample)
                patch_w_dis, real_logits, fake_logits = \
                        self.patch_discriminate_step(x_i_t, x_tilde, cal_gp=False)
                # aux classification loss, computed as in the patch-D step above
                # (c_loss/real_acc/fake_acc were previously used but never defined)
                criterion = nn.NLLLoss()
                c_loss = criterion(real_logits, c_i) + criterion(fake_logits, c_sample)
                real_acc = self.cal_acc(real_logits, c_i)
                fake_acc = self.cal_acc(fake_logits, c_sample)
                patch_loss = hps.beta_dec * patch_w_dis + hps.beta_clf * c_loss
                self.reset_grad([self.Decoder])
                patch_loss.backward()
                self.grad_clip([self.Decoder])
                self.decoder_opt.step()
                info = {
                    f'{flag}/G_patch_w_dis': patch_w_dis.data[0],
                    f'{flag}/c_loss': c_loss.data[0],
                    f'{flag}/real_acc': real_acc,
                    f'{flag}/fake_acc': fake_acc,
                }
                slot_value = (iteration+1, hps.iters) + tuple([value for value in info.values()])
                log = 'G:[%06d/%06d]: patch_w_dis=%.2f, c_loss=%.2f, real_acc=%.2f, fake_acc=%.2f'
                print(log % slot_value)
                for tag, value in info.items():
                    self.logger.scalar_summary(tag, value, iteration + 1)
            if iteration % 1000 == 0 or iteration + 1 == hps.iters:
                self.save_model(model_path, iteration)

if __name__ == '__main__':
    hps = Hps()
    hps.load('./hps/v7.json')
    hps_tuple = hps.get_tuple()
    dataset = myDataset('/storage/raw_feature/voice_conversion/vctk/vctk.h5',
                        '/storage/raw_feature/voice_conversion/vctk/64_513_2000k.json')
    data_loader = DataLoader(dataset)

    solver = Solver(hps_tuple, data_loader)
--------------------------------------------------------------------------------