├── images
│   ├── VC2.png
│   ├── input.png
│   └── output.png
├── report
│   └── 2018102249_홍준식_프로젝트최종보고서.hwp
├── code
│   ├── read_hdf5.py
│   ├── wav_to_hdf5.py
│   ├── convert.py
│   ├── CycleGAN_VC2.py
│   └── train.py
└── README.md

/images/VC2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jshong0907/SingingVoiceConversion/HEAD/images/VC2.png
--------------------------------------------------------------------------------
/images/input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jshong0907/SingingVoiceConversion/HEAD/images/input.png
--------------------------------------------------------------------------------
/images/output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jshong0907/SingingVoiceConversion/HEAD/images/output.png
--------------------------------------------------------------------------------
/report/2018102249_홍준식_프로젝트최종보고서.hwp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jshong0907/SingingVoiceConversion/HEAD/report/2018102249_홍준식_프로젝트최종보고서.hwp
--------------------------------------------------------------------------------
/code/read_hdf5.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np
from torch.utils import data


class ReadHDF5(data.Dataset):  # Dataset wrapper for reading an HDF5 file
    def __init__(self, file_name, key):  # constructor: load the whole dataset under `key` into memory
        self.hf = h5py.File(file_name, 'r')
        self.data = self.hf[key][()]
        self.len = self.data.shape[0]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Singing Voice Conversion

### Goal

> Convert Singer A's singing voice into Singer B's singing voice.

### Result

> ![Input](./images/input.png)

> ![Output](./images/output.png)

### Paper

> CycleGAN-VC2: Improved CycleGAN-based Non-parallel Voice Conversion
>
> ![CycleGAN-VC2](./images/VC2.png)
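
### Usage

> A minimal sketch of the end-to-end pipeline. Paths and singer names mirror the defaults hard-coded in the scripts (`./data/IU/...`, `IU.h5`, `LJB.h5`); treat them as placeholders for your own data.

```python
# Illustrative pipeline sketch -- directory layout and file names are assumptions
# taken from the defaults used in code/wav_to_hdf5.py, train.py and convert.py.
from wav_to_hdf5 import mp3_to_wav, wav_to_hdf5

mp3_to_wav("./data/IU/mp3", "./data/IU/wav")   # 1. decode mp3 files to wav
wav_to_hdf5("./data/IU/wav/", "IU.h5")         # 2. slice into 3-second clips, store 64-bin MFCCs
# Repeat both steps for the target singer to produce LJB.h5, then:
# 3. python train.py    -- train the CycleGAN-VC2 generators/discriminators (checkpoints in output/)
# 4. python convert.py  -- convert a song with the trained A->B generators (writes test.wav, ensemble.wav)
```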
--------------------------------------------------------------------------------
/code/wav_to_hdf5.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np
import pydub
import os
from pydub import AudioSegment
import librosa


def load_and_save(file_dir="./", save_name="default.h5", key_name="default", crop_time=4):
    """Read every mp3/wav in file_dir, cut the raw samples into crop_time-second chunks
    and store them in an HDF5 dataset (one chunk = 44100 * crop_time samples)."""
    hf = h5py.File(file_dir + save_name, "w")
    song_array = []  # list of fixed-length sample chunks
    for file in os.listdir(file_dir):
        if file.endswith(".mp3"):
            song = pydub.AudioSegment.from_mp3(file_dir + "/" + file)

            if song.channels == 2:
                # stereo: split into the two channels and chunk each one separately
                song = np.array(song.get_array_of_samples())
                song = song.reshape((-1, 2))

                for i in range(2):
                    sec3 = []
                    for sec in range(1, len(song)):
                        sec3.append(song[sec][i])

                        if sec % (44100 * crop_time) == 0:
                            song_array.append(sec3)
                            sec3 = []

            else:
                # mono: chunk the single channel
                song = np.array(song.get_array_of_samples())
                song = song.reshape((-1, 1))

                sec3 = []
                for sec in range(1, len(song)):
                    sec3.append(song[sec])

                    if sec % (44100 * crop_time) == 0:
                        song_array.append(sec3)
                        sec3 = []

        if file.endswith(".wav"):
            song = pydub.AudioSegment.from_wav(file_dir + "/" + file)
            song = np.array(song.get_array_of_samples())
            song = song.reshape((-1, 1))

            sec3 = []
            for sec in range(1, len(song)):
                sec3.append(song[sec])

                if sec % (44100 * crop_time) == 0:
                    song_array.append(sec3)
                    sec3 = []

    song_array = np.asarray(song_array).reshape((-1, 44100 * crop_time, 1)).tolist()
    print(np.asarray(song_array).shape)

    print(np.asarray(song_array[0]).shape)
    print(np.asarray(song_array[1]).shape)
    hf.create_dataset(key_name, data=song_array)
    hf.close()


def mp3_to_wav(file_dir="", save_file_dir=""):
    """Decode every mp3 in file_dir and export it as wav into save_file_dir."""
    filenames = os.listdir(file_dir)
    for filename in filenames:
        fname, ext = os.path.splitext(filename)  # strip the extension
        dst = fname + '.wav'

        sound = AudioSegment.from_mp3(file_dir + "/" + filename)
        sound.export(save_file_dir + "/" + dst, format="wav")


def wav_to_hdf5(file_dir="", save_file=""):
    """Load every wav in file_dir at 16 kHz, cut it into 3-second clips,
    extract 64-bin MFCCs per clip and store them under the "song" key."""
    filenames = os.listdir(file_dir)
    hf = h5py.File(save_file, "w")
    song_array = []
    for filename in filenames:
        y, sr = librosa.load(file_dir + "/" + filename, sr=16000)
        sec = int(sr * 3)
        for i in range(int(y.shape[0] / sec)):
            batch = y[i * sec:(i + 1) * sec]
            mfccs = librosa.feature.mfcc(y=np.asarray(batch), n_mfcc=64)
            song_array.append(mfccs)
    song_array = np.asarray(song_array)
    print(song_array.shape)
    hf.create_dataset("song", data=song_array)
    hf.close()
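
# Shape note (illustrative, based on the defaults above): a 3-second clip at 16 kHz is
# 48000 samples; with librosa's default hop length of 512 this yields 1 + 48000 // 512 = 94
# frames, so every entry stored under "song" is a (64, 94) MFCC matrix -- the (64, 94)
# input size that train.py and the CycleGAN_VC2 models assume.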

if __name__ == "__main__":
    wav_to_hdf5("./data/IU/wav/", "IU.h5")
    # mp3_to_wav("./data/LJB/mp3", "./data/LJB/wav")
--------------------------------------------------------------------------------
/code/convert.py:
--------------------------------------------------------------------------------
from CycleGAN_VC2 import Generator, Discriminator
import torch
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import soundfile as sf  # used for writing wav files (librosa.output.write_wav was removed in librosa 0.8)

# Three independently trained IU -> LJB generators; their outputs are ensembled below.
netG_A2B = Generator(num_features=94)
netG_A2B2 = Generator(num_features=94)
netG_A2B3 = Generator(num_features=94)
# netG_B2A = Generator(num_features=94)
netD_A = Discriminator(64, 94)
# netD_B = Discriminator(64, 94)

if torch.cuda.is_available():
    netG_A2B.cuda()
    netG_A2B2.cuda()
    netG_A2B3.cuda()
    # netG_B2A.cuda()
    netD_A.cuda()
    # netD_B.cuda()


netG_A2B.load_state_dict(torch.load('output/netG_IU2LJB.pth'))
netG_A2B2.load_state_dict(torch.load('output/netG_IU2LJB2.pth'))
netG_A2B3.load_state_dict(torch.load('output/netG_IU2LJB3.pth'))
# netG_B2A.load_state_dict(torch.load('output/netG_LJB2IU.pth'))
netD_A.load_state_dict(torch.load('output/netD_IU.pth'))
# netD_B.load_state_dict(torch.load('output/netD_B.pth'))

file_dir = "data/IU/wav"
filename = "GoodDay.wav"

y, sr = librosa.load(file_dir + "/" + filename, sr=16000)
# original = []
result = []
result2 = []
result3 = []
ensemble = []
for i in range(1, 11):  # convert seconds 3-33 of the song, 3 seconds at a time
    yt = y[sr * i * 3:sr * 3 * (i + 1)]
    mfccs = librosa.feature.mfcc(y=yt, n_mfcc=64)
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(mfccs, x_axis='time')
    plt.colorbar()
    plt.title('input')
    plt.tight_layout()
    plt.show()
    mfcc = mfccs
    mfccs = mfccs.reshape((1, 1, mfccs.shape[0], mfccs.shape[1]))
    mfccs = torch.from_numpy(mfccs).float().cuda()
    output = netG_A2B(mfccs)
    output2 = netG_A2B2(mfccs)
    output3 = netG_A2B3(mfccs)
    print(netD_A(mfccs))   # discriminator score on the real input
    print(netD_A(output))  # discriminator score on the converted output
    output = output.cpu().detach().numpy()
    output2 = output2.cpu().detach().numpy()
    output3 = output3.cpu().detach().numpy()
    output = output.reshape((output.shape[2], output.shape[3]))
    output2 = output2.reshape((output2.shape[2], output2.shape[3]))
    output3 = output3.reshape((output3.shape[2], output3.shape[3]))
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(output, x_axis='time')
    plt.colorbar()
    plt.title('output')
    plt.tight_layout()
    plt.show()
    # invert the converted MFCCs back to audio
    audio = librosa.feature.inverse.mfcc_to_audio(output)
    audio2 = librosa.feature.inverse.mfcc_to_audio(output2)
    audio3 = librosa.feature.inverse.mfcc_to_audio(output3)
    # originals = librosa.feature.inverse.mfcc_to_audio(mfcc)
    for t in audio:
        result.append(t)
    for t in audio2:
        result2.append(t)
    for t in audio3:
        result3.append(t)
    for s in range(len(audio)):
        # ensemble: average the three generators' waveforms sample by sample
        ensemble.append((audio[s] + audio2[s] + audio3[s]) / 3)
    # for k in originals:
    #     original.append(k)

print(result)
result = np.asarray(result)
result2 = np.asarray(result2)
result3 = np.asarray(result3)
ensemble = np.asarray(ensemble, dtype=np.float32)
print(type(ensemble[0]))
# original = np.asarray(original)
print(result[0])

# write the converted waveforms to disk
sf.write("test.wav", result, 16000)
sf.write("test2.wav", result2, 16000)
sf.write("test3.wav", result3, 16000)
sf.write("ensemble.wav", ensemble, 16000)
# sf.write("original.wav", original, 16000)
--------------------------------------------------------------------------------
/code/CycleGAN_VC2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownSampleBlock(nn.Module):
    def __init__(self, inchannels, outchannels):
        super(DownSampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=inchannels, out_channels=outchannels, kernel_size=(5, 5),
                              stride=(2, 2), padding=(2, 2), bias=False)

    def forward(self, x):
        # strided conv halves both spatial dims; GLU halves the channel dim
        x = F.instance_norm(self.conv(x))
        return F.glu(x, dim=1)


class UpSampleBlock(nn.Module):
    def __init__(self, inchannels, outchannels):
        super(UpSampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=inchannels, out_channels=outchannels, kernel_size=(5, 5),
                              stride=(1, 1), padding=(2, 2), bias=False)

    def forward(self, x):
        # PixelShuffle doubles both spatial dims (and divides channels by 4); GLU halves the channel dim
        x = F.pixel_shuffle(F.instance_norm(self.conv(x)), 2)
        return F.glu(x, dim=1)


class ResidualBlock(nn.Module):
    def __init__(self):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(1, 3),
                               stride=(1, 1), padding=(0, 1), bias=False)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(1, 3),
                               stride=(1, 1), padding=(0, 1), bias=False)

    def forward(self, x):
        res = x
        x = self.conv1(x)
        x = F.instance_norm(x)
        x = F.glu(x, dim=1)
        x = self.conv2(x)
        x = F.instance_norm(x)
        return x + res


class Generator(nn.Module):
    """CycleGAN-VC2-style 2-1-2D generator: 2D downsampling, 1D residual blocks
    (via a reshape), then 2D upsampling back to the input resolution."""

    def __init__(self, num_features):
        super(Generator, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(5, 15),
                               stride=(1, 1), padding=(2, 7), bias=False)

        self.downsample_block1 = DownSampleBlock(inchannels=64, outchannels=256)
        self.downsample_block2 = DownSampleBlock(inchannels=128, outchannels=512)

        # 2D -> 1D: collapse the feature axis into the channel axis before the residual blocks
        self.conv2 = nn.Conv2d(in_channels=int(int((num_features + 1) / 2 + 1) / 2) * 256, out_channels=256,
                               kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        for i in range(6):
            self.add_module("residual_block" + str(i + 1), ResidualBlock())

        # 1D -> 2D: expand the channel axis back into the feature axis after the residual blocks
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=int(int((num_features + 1) / 2 + 1) / 2) * 256,
                               kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.upsample_block1 = UpSampleBlock(inchannels=256, outchannels=1024)
        self.upsample_block2 = UpSampleBlock(inchannels=128, outchannels=512)

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=num_features, kernel_size=(5, 15),
                               stride=(1, 1), padding=(2, 7), bias=False)

        self.output = nn.Conv2d(in_channels=int(int((num_features + 1) / 2 + 1) / 2) * 4, out_channels=1,
                                kernel_size=(5, 15), stride=(1, 1), padding=(2, 7), bias=False)

    def forward(self, x):
        time = x.shape[2]
        num_features = x.shape[3]
        x = self.conv1(x)
        x = F.glu(x, dim=1)
        for i in range(2):
            x = self.__getattr__("downsample_block" + str(i + 1))(x)
        x = x.reshape((x.shape[0], -1, int(int((time + 1) / 2 + 1) / 2), 1))
        x = self.conv2(x)
        x = F.instance_norm(x)

        for i in range(6):
            x = self.__getattr__("residual_block" + str(i + 1))(x)

        x = self.conv3(x)
        x = F.instance_norm(x)
        x = x.reshape(x.shape[0], -1, int(int((time + 1) / 2 + 1) / 2), int(int((num_features + 1) / 2 + 1) / 2))

        for i in range(2):
            x = self.__getattr__("upsample_block" + str(i + 1))(x)

        x = self.conv4(x)

        # reinterpret the (N, num_features, time, C) buffer as (N, C, time, num_features)
        # (a reshape, not a transpose) before projecting back to a single channel
        x = x.reshape(x.shape[0], x.shape[3], x.shape[2], x.shape[1])

        return self.output(x)


class Discriminator(nn.Module):
    """Discriminator that downsamples the whole (time, feature) map and scores it with a single linear unit."""

    def __init__(self, width, height):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3),
                               stride=(1, 1), padding=(1, 1), bias=False)

        self.downsample_block1 = DownSampleBlock(inchannels=64, outchannels=256)
        self.downsample_block2 = DownSampleBlock(inchannels=128, outchannels=512)
        self.downsample_block3 = DownSampleBlock(inchannels=256, outchannels=1024)
        self.conv2 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(1, 5),
                               stride=(1, 1), padding=(0, 2), bias=False)

        self.conv3 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=(1, 3),
                               stride=(1, 1), padding=(0, 1), bias=False)

        # after three stride-2 downsamples the map is reduced to ceil(width/8) x ceil(height/8)
        self.fc = nn.Linear(int((width + 7) / 8) * int((height + 7) / 8), 1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.glu(x, dim=1)

        for i in range(3):
            x = self.__getattr__("downsample_block" + str(i + 1))(x)

        x = self.conv2(x)
        x = F.instance_norm(x)
        x = F.glu(x, dim=1)
        x = self.conv3(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return torch.sigmoid(x)
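

if __name__ == "__main__":
    # Quick shape sanity check (illustrative): two random 64x94 MFCC "images",
    # the input size used by train.py and convert.py.
    g = Generator(num_features=94)
    d = Discriminator(64, 94)
    x = torch.randn(2, 1, 64, 94)
    print(g(x).shape)  # expected: torch.Size([2, 1, 64, 94])
    print(d(x).shape)  # expected: torch.Size([2, 1])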
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
import torch
from CycleGAN_VC2 import Generator, Discriminator
from torch.utils.data import Dataset, DataLoader
import h5py
import time


singerA_dataset = "./IU.h5"
singerB_dataset = "./LJB.h5"
batchSize = 32
load_model = True


if not torch.cuda.is_available():
    print("WARNING: no CUDA device found; this script expects a GPU")


class SongDataset(Dataset):
    """Dataset of (1, 64, 94) MFCC examples stored under the "song" key by wav_to_hdf5.py."""

    def __init__(self, filename):
        self.data = h5py.File(filename, 'r')
        self.song = self.data["song"][()]
        self.x = self.song.reshape((-1, 1, 64, 94))
        self.len = self.x.shape[0]

    def __getitem__(self, index):
        return self.x[index]

    def __len__(self):
        return self.len


train_A = SongDataset(singerA_dataset)
train_B = SongDataset(singerB_dataset)

train_loader_A = DataLoader(dataset=train_A, batch_size=batchSize, shuffle=True)
train_loader_B = DataLoader(dataset=train_B, batch_size=batchSize, shuffle=True)

netG_A2B = Generator(num_features=94)
netG_B2A = Generator(num_features=94)
netD_A = Discriminator(64, 94)
netD_B = Discriminator(64, 94)

if torch.cuda.is_available():
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

if load_model:
    netG_A2B.load_state_dict(torch.load('output/netG_IU2LJB.pth'))
    netG_B2A.load_state_dict(torch.load('output/netG_LJB2IU.pth'))
    netD_A.load_state_dict(torch.load('output/netD_IU.pth'))
    netD_B.load_state_dict(torch.load('output/netD_LJB.pth'))

criterion_GAN = torch.nn.BCELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers
optimizer_G = torch.optim.Adam(list(netG_A2B.parameters()) + list(netG_B2A.parameters()), lr=0.001)
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=0.001)
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=0.001)

# Targets for the BCE GAN loss; shaped (batchSize, 1) to match the discriminator output
Tensor = torch.cuda.FloatTensor
target_real = Tensor(batchSize, 1).fill_(1.0)
target_fake = Tensor(batchSize, 1).fill_(0.0)


###################################

###### Training ######
for epoch in range(1000):
    start = time.time()
    for i, (real_A, real_B) in enumerate(zip(train_loader_A, train_loader_B)):

        # skip the last, smaller batch so the fixed-size targets always match
        if real_A.shape[0] != batchSize:
            break
        if real_B.shape[0] != batchSize:
            break

        real_A = real_A.cuda().float()
        real_B = real_B.cuda().float()

        ############ Generator ############
        optimizer_G.zero_grad()

        netG_A2B.train()
        netG_B2A.train()
        netD_A.eval()
        netD_B.eval()

        # GAN loss: the generators try to fool the discriminators
        G_A2B = netG_A2B(real_A)
        G_B2A = netG_B2A(real_B)
        pred_fake_B = netD_B(G_A2B)
        pred_fake_A = netD_A(G_B2A)
        gan_loss_A = criterion_GAN(pred_fake_A, target_real)
        gan_loss_B = criterion_GAN(pred_fake_B, target_real)

        gan_loss = (gan_loss_A + gan_loss_B) * 10

        # identity loss: feeding A into the B->A generator should return A, and vice versa
        G_A2A = netG_B2A(real_A)
        G_B2B = netG_A2B(real_B)
        identity_loss_A = criterion_identity(G_A2A, real_A)
        identity_loss_B = criterion_identity(G_B2B, real_B)

        identity_loss = (identity_loss_A + identity_loss_B) * 5

        # cycle-consistency loss: A -> B -> A should reconstruct A, and B -> A -> B should reconstruct B
        G_A2B2A = netG_B2A(G_A2B)
        G_B2A2B = netG_A2B(G_B2A)

        cycle_loss_A = criterion_cycle(G_A2B2A, real_A)
        cycle_loss_B = criterion_cycle(G_B2A2B, real_B)

        cycle_loss = (cycle_loss_A + cycle_loss_B) * 10

        # adversarial loss on the cycled outputs
        pred_fake_B2A2B = netD_B(G_B2A2B)
        pred_fake_A2B2A = netD_A(G_A2B2A)

        cycle_gan_loss_A = criterion_GAN(pred_fake_A2B2A, target_real)
        cycle_gan_loss_B = criterion_GAN(pred_fake_B2A2B, target_real)

        cycle_gan_loss = (cycle_gan_loss_A + cycle_gan_loss_B) * 10

        # Total loss
        loss_G = identity_loss + gan_loss + cycle_loss + cycle_gan_loss
        loss_G.backward()

        optimizer_G.step()
        ###################################

        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        netG_A2B.eval()
        netG_B2A.eval()
        netD_A.train()
        netD_B.train()

        # Real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss
        pred_fake = netD_A(netG_B2A(real_B))
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()

        optimizer_D_A.step()
        ###################################

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss
        pred_fake = netD_B(netG_A2B(real_A))
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()

        optimizer_D_B.step()
    print("time :", time.time() - start)
    print("epoch ", epoch + 1, " finished")

    # Save model checkpoints
    torch.save(netG_A2B.state_dict(), 'output/netG_IU2LJB.pth')
    torch.save(netG_B2A.state_dict(), 'output/netG_LJB2IU.pth')
    torch.save(netD_A.state_dict(), 'output/netD_IU.pth')
    torch.save(netD_B.state_dict(), 'output/netD_LJB.pth')
    ###################################
--------------------------------------------------------------------------------