├── images
│   ├── apple2orange
│   │   ├── AtoB
│   │   │   ├── 1_input.png
│   │   │   ├── 1_output.png
│   │   │   ├── 1_recon.png
│   │   │   ├── 2_input.png
│   │   │   ├── 2_output.png
│   │   │   ├── 2_recon.png
│   │   │   ├── 3_input.png
│   │   │   ├── 3_output.png
│   │   │   ├── 3_recon.png
│   │   │   ├── 4_input.png
│   │   │   ├── 4_output.png
│   │   │   ├── 4_recon.png
│   │   │   ├── 5_input.png
│   │   │   ├── 5_output.png
│   │   │   └── 5_recon.png
│   │   └── BtoA
│   │       ├── 1_input.png
│   │       ├── 1_output.png
│   │       ├── 1_recon.png
│   │       ├── 2_input.png
│   │       ├── 2_output.png
│   │       ├── 2_recon.png
│   │       ├── 3_input.png
│   │       ├── 3_output.png
│   │       ├── 3_recon.png
│   │       ├── 4_input.png
│   │       ├── 4_output.png
│   │       ├── 4_recon.png
│   │       ├── 5_input.png
│   │       ├── 5_output.png
│   │       └── 5_recon.png
│   └── horse2zebra
│       ├── AtoB
│       │   ├── 1_input.png
│       │   ├── 1_output.png
│       │   ├── 1_recon.png
│       │   ├── 2_input.png
│       │   ├── 2_output.png
│       │   ├── 2_recon.png
│       │   ├── 3_input.png
│       │   ├── 3_output.png
│       │   ├── 3_recon.png
│       │   ├── 4_input.png
│       │   ├── 4_output.png
│       │   ├── 4_recon.png
│       │   ├── 5_input.png
│       │   ├── 5_output.png
│       │   └── 5_recon.png
│       └── BtoA
│           ├── 1_input.png
│           ├── 1_output.png
│           ├── 1_recon.png
│           ├── 2_input.png
│           ├── 2_output.png
│           ├── 2_recon.png
│           ├── 3_input.png
│           ├── 3_output.png
│           ├── 3_recon.png
│           ├── 4_input.png
│           ├── 4_output.png
│           ├── 4_recon.png
│           ├── 5_input.png
│           ├── 5_output.png
│           └── 5_recon.png
├── network.py
├── README.md
├── util.py
└── pytorch_cycleGAN.py
--------------------------------------------------------------------------------
/images/**/*.png:
--------------------------------------------------------------------------------
(Binary sample images, one per file in the tree above; each is hosted at
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/<dataset>/<AtoB|BtoA>/<n>_{input,output,recon}.png)
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class generator(nn.Module):
    # initializers
    def __init__(self, input_nc, output_nc, ngf=32, nb=6):
        super(generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.nb = nb
        self.conv1 = nn.Conv2d(input_nc, ngf, 7, 1, 0)
        self.conv1_norm = nn.InstanceNorm2d(ngf)
        self.conv2 = nn.Conv2d(ngf, ngf * 2, 3, 2, 1)
        self.conv2_norm = nn.InstanceNorm2d(ngf * 2)
        self.conv3 = nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1)
        self.conv3_norm = nn.InstanceNorm2d(ngf * 4)

        self.resnet_blocks = []
        for i in range(nb):
            self.resnet_blocks.append(resnet_block(ngf * 4, 3, 1, 1))
            self.resnet_blocks[i].weight_init(0, 0.02)

        self.resnet_blocks = nn.Sequential(*self.resnet_blocks)

        self.deconv1 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1)
        self.deconv1_norm = nn.InstanceNorm2d(ngf * 2)
        self.deconv2 = nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1)
        self.deconv2_norm = nn.InstanceNorm2d(ngf)
        self.deconv3 = nn.Conv2d(ngf, output_nc, 7, 1, 0)  # final 7x7 conv, applied after reflection padding

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input):
        x = F.pad(input, (3, 3, 3, 3), 'reflect')
        x = F.relu(self.conv1_norm(self.conv1(x)))
        x = F.relu(self.conv2_norm(self.conv2(x)))
        x = F.relu(self.conv3_norm(self.conv3(x)))
        x = self.resnet_blocks(x)
        x = F.relu(self.deconv1_norm(self.deconv1(x)))
        x = F.relu(self.deconv2_norm(self.deconv2(x)))
        x = F.pad(x, (3, 3, 3, 3), 'reflect')
        o = F.tanh(self.deconv3(x))

        return o

class discriminator(nn.Module):
    # initializers
    def __init__(self, input_nc, output_nc, ndf=64):
        super(discriminator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ndf = ndf
        self.conv1 = nn.Conv2d(input_nc, ndf, 4, 2, 1)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1)
        self.conv2_norm = nn.InstanceNorm2d(ndf * 2)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1)
        self.conv3_norm = nn.InstanceNorm2d(ndf * 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 1, 1)
        self.conv4_norm = nn.InstanceNorm2d(ndf * 8)
        self.conv5 = nn.Conv2d(ndf * 8, output_nc, 4, 1, 1)  # PatchGAN output: one score per patch

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_norm(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_norm(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_norm(self.conv4(x)), 0.2)
        x = self.conv5(x)

        return x

# resnet block with reflect padding
class resnet_block(nn.Module):
    def __init__(self, channel, kernel, stride, padding):
        super(resnet_block, self).__init__()
        self.channel = channel
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.conv1 = nn.Conv2d(channel, channel, kernel, stride, 0)
        self.conv1_norm = nn.InstanceNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, kernel, stride, 0)
        self.conv2_norm = nn.InstanceNorm2d(channel)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward(self, input):
        x = F.pad(input, (self.padding, self.padding, self.padding, self.padding), 'reflect')
        x = F.relu(self.conv1_norm(self.conv1(x)))
        x = F.pad(x, (self.padding, self.padding, self.padding, self.padding), 'reflect')
        x = self.conv2_norm(self.conv2(x))

        return input + x

def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
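A quick shape check for the two networks above (a minimal sketch, not part of the repo; runs on CPU with randomly initialized weights):

```python
import torch
from torch.autograd import Variable  # pytorch 0.1.x-era API, matching the repo
import network

G = network.generator(input_nc=3, output_nc=3, ngf=32, nb=9)  # 9 resnet blocks, as used in training
D = network.discriminator(input_nc=3, output_nc=1, ndf=64)

x = Variable(torch.randn(1, 3, 256, 256))  # a random RGB image scaled to [-1, 1]
y = G(x)  # translated image, same resolution as the input
p = D(y)  # PatchGAN score map, one score per overlapping patch
print(y.size(), p.size())  # expected: (1, 3, 256, 256) and (1, 1, 30, 30)
```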
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-CycleGAN
PyTorch implementation of CycleGAN [1].

* You can download the datasets from https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
* See https://arxiv.org/pdf/1703.10593.pdf for more detail on the network architecture and training procedure.

## Dataset
* apple2orange
  * apple training images: 995, orange training images: 1,019, apple test images: 266, orange test images: 248
* horse2zebra
  * horse training images: 1,067, zebra training images: 1,334, horse test images: 120, zebra test images: 140
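## Usage
A minimal training invocation (a sketch; the flags are defined in `pytorch_cycleGAN.py`, and `util.data_load` expects the dataset unpacked under `data/<dataset>/` with `trainA`, `trainB`, `testA`, and `testB` subfolders):

```
python pytorch_cycleGAN.py --dataset apple2orange --train_epoch 200 --decay_epoch 100
```

Sample translations, model weights, and the loss-history plot are written under `<dataset>_results/` by default.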
## Results
### apple2orange (after 200 epochs)
* apple2orange

| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/apple2orange/AtoB/1_input.png) | ![](images/apple2orange/AtoB/1_output.png) | ![](images/apple2orange/AtoB/1_recon.png) |
| ![](images/apple2orange/AtoB/2_input.png) | ![](images/apple2orange/AtoB/2_output.png) | ![](images/apple2orange/AtoB/2_recon.png) |
| ![](images/apple2orange/AtoB/3_input.png) | ![](images/apple2orange/AtoB/3_output.png) | ![](images/apple2orange/AtoB/3_recon.png) |
| ![](images/apple2orange/AtoB/4_input.png) | ![](images/apple2orange/AtoB/4_output.png) | ![](images/apple2orange/AtoB/4_recon.png) |
| ![](images/apple2orange/AtoB/5_input.png) | ![](images/apple2orange/AtoB/5_output.png) | ![](images/apple2orange/AtoB/5_recon.png) |
* orange2apple

| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/apple2orange/BtoA/1_input.png) | ![](images/apple2orange/BtoA/1_output.png) | ![](images/apple2orange/BtoA/1_recon.png) |
| ![](images/apple2orange/BtoA/2_input.png) | ![](images/apple2orange/BtoA/2_output.png) | ![](images/apple2orange/BtoA/2_recon.png) |
| ![](images/apple2orange/BtoA/3_input.png) | ![](images/apple2orange/BtoA/3_output.png) | ![](images/apple2orange/BtoA/3_recon.png) |
| ![](images/apple2orange/BtoA/4_input.png) | ![](images/apple2orange/BtoA/4_output.png) | ![](images/apple2orange/BtoA/4_recon.png) |
| ![](images/apple2orange/BtoA/5_input.png) | ![](images/apple2orange/BtoA/5_output.png) | ![](images/apple2orange/BtoA/5_recon.png) |
* Learning Time
  * apple2orange - Avg. per epoch: 299.38 sec; Total 200 epochs: 62,225.33 sec

### horse2zebra (after 200 epochs)
* horse2zebra

| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/horse2zebra/AtoB/1_input.png) | ![](images/horse2zebra/AtoB/1_output.png) | ![](images/horse2zebra/AtoB/1_recon.png) |
| ![](images/horse2zebra/AtoB/2_input.png) | ![](images/horse2zebra/AtoB/2_output.png) | ![](images/horse2zebra/AtoB/2_recon.png) |
| ![](images/horse2zebra/AtoB/3_input.png) | ![](images/horse2zebra/AtoB/3_output.png) | ![](images/horse2zebra/AtoB/3_recon.png) |
| ![](images/horse2zebra/AtoB/4_input.png) | ![](images/horse2zebra/AtoB/4_output.png) | ![](images/horse2zebra/AtoB/4_recon.png) |
| ![](images/horse2zebra/AtoB/5_input.png) | ![](images/horse2zebra/AtoB/5_output.png) | ![](images/horse2zebra/AtoB/5_recon.png) |
* zebra2horse

| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/horse2zebra/BtoA/1_input.png) | ![](images/horse2zebra/BtoA/1_output.png) | ![](images/horse2zebra/BtoA/1_recon.png) |
| ![](images/horse2zebra/BtoA/2_input.png) | ![](images/horse2zebra/BtoA/2_output.png) | ![](images/horse2zebra/BtoA/2_recon.png) |
| ![](images/horse2zebra/BtoA/3_input.png) | ![](images/horse2zebra/BtoA/3_output.png) | ![](images/horse2zebra/BtoA/3_recon.png) |
| ![](images/horse2zebra/BtoA/4_input.png) | ![](images/horse2zebra/BtoA/4_output.png) | ![](images/horse2zebra/BtoA/4_recon.png) |
| ![](images/horse2zebra/BtoA/5_input.png) | ![](images/horse2zebra/BtoA/5_output.png) | ![](images/horse2zebra/BtoA/5_recon.png) |
* Learning Time
  * horse2zebra - Avg. per epoch: 299.25 sec; Total 200 epochs: 61,221.27 sec

## Development Environment

* Ubuntu 14.04 LTS
* NVIDIA GTX 1080 Ti
* CUDA 8.0
* Python 2.7.6
* pytorch 0.1.12
* matplotlib 1.3.1
* scipy 0.19.1

## Reference

[1] Zhu, Jun-Yan, et al. "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv preprint arXiv:1703.10593 (2017).

(Full paper: https://arxiv.org/pdf/1703.10593.pdf)
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import itertools, imageio, torch, random
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets
from scipy.misc import imresize
from torch.autograd import Variable

def show_result(G, x_, y_, num_epoch, show = False, save = False, path = 'result.png'):
    test_images = G(x_)

    size_figure_grid = 3
    fig, ax = plt.subplots(x_.size()[0], size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(x_.size()[0]), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    for i in range(x_.size()[0]):
        ax[i, 0].cla()
        ax[i, 0].imshow((x_[i].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
        ax[i, 1].cla()
        ax[i, 1].imshow((test_images[i].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
        ax[i, 2].cla()
        ax[i, 2].imshow((y_[i].numpy().transpose(1, 2, 0) + 1) / 2)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    x = range(len(hist['D_A_losses']))

    y1 = hist['D_A_losses']
    y2 = hist['D_B_losses']
    y3 = hist['G_A_losses']
    y4 = hist['G_B_losses']
    y5 = hist['A_cycle_losses']
    y6 = hist['B_cycle_losses']

    plt.plot(x, y1, label='D_A_loss')
    plt.plot(x, y2, label='D_B_loss')
    plt.plot(x, y3, label='G_A_loss')
    plt.plot(x, y4, label='G_B_loss')
    plt.plot(x, y5, label='A_cycle_loss')
    plt.plot(x, y6, label='B_cycle_loss')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

def generate_animation(root, model, opt):
    images = []
    for e in range(opt.train_epoch):
        img_name = root + 'Fixed_results/' + model + str(e + 1) + '.png'
        images.append(imageio.imread(img_name))
    imageio.mimsave(root + model + 'generate_animation.gif', images, fps=5)

def data_load(path, subfolder, transform, batch_size, shuffle=False):
    # keep only the images that belong to the requested subfolder (class)
    dset = datasets.ImageFolder(path, transform)
    ind = dset.class_to_idx[subfolder]

    n = 0
    for i in range(dset.__len__()):
        if ind != dset.imgs[n][1]:
            del dset.imgs[n]
            n -= 1

        n += 1

    return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle)

def imgs_resize(imgs, resize_scale = 286):
    outputs = torch.FloatTensor(imgs.size()[0], imgs.size()[1], resize_scale, resize_scale)
    for i in range(imgs.size()[0]):
        img = imresize(imgs[i].numpy(), [resize_scale, resize_scale])
        outputs[i] = torch.FloatTensor((img.transpose(2, 0, 1).astype(np.float32).reshape(-1, imgs.size()[1], resize_scale, resize_scale) - 127.5) / 127.5)

    return outputs
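# NOTE: scipy.misc.imresize (imported above) was removed in SciPy >= 1.3, so
# imgs_resize fails on newer stacks. A PIL-based stand-in mirroring the old
# call signature could look like the following (hypothetical helper, not part
# of the original repo; assumes an HxWxC uint8 array):
#
#   from PIL import Image
#   def imresize(arr, size):
#       return np.array(Image.fromarray(arr).resize((size[1], size[0]), Image.BILINEAR))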
def random_crop(imgs, crop_size = 256):
    outputs = torch.FloatTensor(imgs.size()[0], imgs.size()[1], crop_size, crop_size)
    for i in range(imgs.size()[0]):
        img = imgs[i]
        rand1 = np.random.randint(0, imgs.size()[2] - crop_size)
        rand2 = np.random.randint(0, imgs.size()[3] - crop_size)  # was size()[2]; use the width dimension
        outputs[i] = img[:, rand1: crop_size + rand1, rand2: crop_size + rand2]

    return outputs

def random_fliplr(imgs):
    outputs = torch.FloatTensor(imgs.size())
    for i in range(imgs.size()[0]):
        if torch.rand(1)[0] < 0.5:
            img = torch.FloatTensor(
                (np.fliplr(imgs[i].numpy().transpose(1, 2, 0)).transpose(2, 0, 1).reshape(-1, imgs.size()[1], imgs.size()[2], imgs.size()[3]) + 1) / 2)
            outputs[i] = (img - 0.5) / 0.5
        else:
            outputs[i] = imgs[i]

    return outputs

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

# superseded by ImagePool below (its use is commented out in pytorch_cycleGAN.py)
class image_store():
    def __init__(self, store_size=50):
        self.store_size = store_size
        self.num_img = 0
        self.images = []

    def query(self, image):
        select_imgs = []
        for i in range(image.size()[0]):
            if self.num_img < self.store_size:
                self.images.append(image)
                select_imgs.append(image)
                self.num_img += 1
            else:
                prob = np.random.uniform(0, 1)
                if prob > 0.5:
                    ind = np.random.randint(0, self.store_size)  # upper bound is exclusive; was store_size - 1
                    select_imgs.append(self.images[ind])
                    self.images[ind] = image
                else:
                    select_imgs.append(image)

        return Variable(torch.cat(select_imgs, 0))

class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        for image in images.data:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images
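`ImagePool` above implements the 50-image history buffer from [1] (after Shrivastava et al.): each queried fake may be stored, and the discriminator is shown either the current fake or a randomly chosen older one, with probability 0.5 each. A minimal usage sketch (stand-in tensors; old-style `Variable` API to match the repo):

```python
import torch
from torch.autograd import Variable
from util import ImagePool

pool = ImagePool(50)
for step in range(100):
    fake = Variable(torch.randn(1, 3, 256, 256))  # stand-in for G_A(realA)
    fake_for_D = pool.query(fake)  # the current fake, or an older one from the buffer
    assert fake_for_D.size() == fake.size()
```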
--------------------------------------------------------------------------------
/pytorch_cycleGAN.py:
--------------------------------------------------------------------------------
import os, time, pickle, argparse, network, util, itertools
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.autograd import Variable

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, default='apple2orange', help='dataset name')
parser.add_argument('--train_subfolder', required=False, default='train', help='train subfolder prefix')
parser.add_argument('--test_subfolder', required=False, default='test', help='test subfolder prefix')
parser.add_argument('--input_ngc', type=int, default=3, help='input channel for generator')
parser.add_argument('--output_ngc', type=int, default=3, help='output channel for generator')
parser.add_argument('--input_ndc', type=int, default=3, help='input channel for discriminator')
parser.add_argument('--output_ndc', type=int, default=1, help='output channel for discriminator')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--ngf', type=int, default=32)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nb', type=int, default=9, help='the number of resnet blocks in the generator')
parser.add_argument('--input_size', type=int, default=256, help='input size')
parser.add_argument('--resize_scale', type=int, default=286, help='resize scale (0 disables resizing)')
# caveat: argparse's type=bool treats any non-empty string as True, so these two
# flags can effectively only be disabled by editing the defaults
parser.add_argument('--crop', type=bool, default=True, help='random crop True or False')
parser.add_argument('--fliplr', type=bool, default=True, help='random fliplr True or False')
parser.add_argument('--train_epoch', type=int, default=200, help='number of training epochs')
parser.add_argument('--decay_epoch', type=int, default=100, help='epoch at which learning rate decay starts')
parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--lambdaA', type=float, default=10, help='lambdaA for cycle loss')
parser.add_argument('--lambdaB', type=float, default=10, help='lambdaB for cycle loss')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--save_root', required=False, default='results', help='results save path')
opt = parser.parse_args()
print('------------ Options -------------')
for k, v in sorted(vars(opt).items()):
    print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')

# results save path
root = opt.dataset + '_' + opt.save_root + '/'
model = opt.dataset + '_'
if not os.path.isdir(root):
    os.mkdir(root)
if not os.path.isdir(root + 'test_results'):
    os.mkdir(root + 'test_results')
if not os.path.isdir(root + 'test_results/AtoB'):
    os.mkdir(root + 'test_results/AtoB')
if not os.path.isdir(root + 'test_results/BtoA'):
    os.mkdir(root + 'test_results/BtoA')
# the per-epoch sample dump below also writes into train_results, so create it too
if not os.path.isdir(root + 'train_results/AtoB'):
    os.makedirs(root + 'train_results/AtoB')
if not os.path.isdir(root + 'train_results/BtoA'):
    os.makedirs(root + 'train_results/BtoA')

# data_loader; expects data/<dataset>/{trainA, trainB, testA, testB}
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_loader_A = util.data_load('data/' + opt.dataset, opt.train_subfolder + 'A', transform, opt.batch_size, shuffle=True)
train_loader_B = util.data_load('data/' + opt.dataset, opt.train_subfolder + 'B', transform, opt.batch_size, shuffle=True)
test_loader_A = util.data_load('data/' + opt.dataset, opt.test_subfolder + 'A', transform, opt.batch_size, shuffle=False)
test_loader_B = util.data_load('data/' + opt.dataset, opt.test_subfolder + 'B', transform, opt.batch_size, shuffle=False)
# network
G_A = network.generator(opt.input_ngc, opt.output_ngc, opt.ngf, opt.nb)
G_B = network.generator(opt.input_ngc, opt.output_ngc, opt.ngf, opt.nb)
D_A = network.discriminator(opt.input_ndc, opt.output_ndc, opt.ndf)
D_B = network.discriminator(opt.input_ndc, opt.output_ndc, opt.ndf)
G_A.weight_init(mean=0.0, std=0.02)
G_B.weight_init(mean=0.0, std=0.02)
D_A.weight_init(mean=0.0, std=0.02)
D_B.weight_init(mean=0.0, std=0.02)
G_A.cuda()
G_B.cuda()
D_A.cuda()
D_B.cuda()
G_A.train()
G_B.train()
D_A.train()
D_B.train()
print('---------- Networks initialized -------------')
util.print_network(G_A)
util.print_network(G_B)
util.print_network(D_A)
util.print_network(D_B)
print('-----------------------------------------------')

# loss
BCE_loss = nn.BCELoss().cuda()  # unused: the LSGAN objective below uses MSE instead
MSE_loss = nn.MSELoss().cuda()
L1_loss = nn.L1Loss().cuda()

# Adam optimizer (both generators share one optimizer; each discriminator has its own)
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=opt.lrG, betas=(opt.beta1, opt.beta2))
D_A_optimizer = optim.Adam(D_A.parameters(), lr=opt.lrD, betas=(opt.beta1, opt.beta2))
D_B_optimizer = optim.Adam(D_B.parameters(), lr=opt.lrD, betas=(opt.beta1, opt.beta2))

# image store
# fakeA_store = util.image_store(50)
# fakeB_store = util.image_store(50)
fakeA_store = util.ImagePool(50)
fakeB_store = util.ImagePool(50)

train_hist = {}
train_hist['D_A_losses'] = []
train_hist['D_B_losses'] = []
train_hist['G_A_losses'] = []
train_hist['G_B_losses'] = []
train_hist['A_cycle_losses'] = []
train_hist['B_cycle_losses'] = []
train_hist['per_epoch_ptimes'] = []
train_hist['total_ptime'] = []

print('training start!')
start_time = time.time()
for epoch in range(opt.train_epoch):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    A_cycle_losses = []
    B_cycle_losses = []
    epoch_start_time = time.time()
    num_iter = 0
    if (epoch+1) > opt.decay_epoch:
        # decay both learning rates linearly to zero over the remaining epochs
        D_A_optimizer.param_groups[0]['lr'] -= opt.lrD / (opt.train_epoch - opt.decay_epoch)
        D_B_optimizer.param_groups[0]['lr'] -= opt.lrD / (opt.train_epoch - opt.decay_epoch)
        G_optimizer.param_groups[0]['lr'] -= opt.lrG / (opt.train_epoch - opt.decay_epoch)

    for (realA, _), (realB, _) in itertools.izip(train_loader_A, train_loader_B):  # Python 2 izip; use zip() on Python 3

        if opt.resize_scale:
            realA = util.imgs_resize(realA, opt.resize_scale)
            realB = util.imgs_resize(realB, opt.resize_scale)

        if opt.crop:
            realA = util.random_crop(realA, opt.input_size)
            realB = util.random_crop(realB, opt.input_size)

        if opt.fliplr:
            realA = util.random_fliplr(realA)
            realB = util.random_fliplr(realB)

        realA, realB = Variable(realA.cuda()), Variable(realB.cuda())

        # train generator G
        G_optimizer.zero_grad()

        # translate real A to fake B; D_A(G_A(A))
        fakeB = G_A(realA)
        D_A_result = D_A(fakeB)
        G_A_loss = MSE_loss(D_A_result, Variable(torch.ones(D_A_result.size()).cuda()))

        # reconstruct fake B back to rec A; G_B(G_A(A))
        recA = G_B(fakeB)
        A_cycle_loss = L1_loss(recA, realA) * opt.lambdaA

        # translate real B to fake A; D_B(G_B(B))
        fakeA = G_B(realB)
        D_B_result = D_B(fakeA)
        G_B_loss = MSE_loss(D_B_result, Variable(torch.ones(D_B_result.size()).cuda()))

        # reconstruct fake A back to rec B; G_A(G_B(B))
        recB = G_A(fakeA)
        B_cycle_loss = L1_loss(recB, realB) * opt.lambdaB

        G_loss = G_A_loss + G_B_loss + A_cycle_loss + B_cycle_loss
        G_loss.backward()
        G_optimizer.step()

        train_hist['G_A_losses'].append(G_A_loss.data[0])
        train_hist['G_B_losses'].append(G_B_loss.data[0])
        train_hist['A_cycle_losses'].append(A_cycle_loss.data[0])
        train_hist['B_cycle_losses'].append(B_cycle_loss.data[0])
        G_A_losses.append(G_A_loss.data[0])
        G_B_losses.append(G_B_loss.data[0])
        A_cycle_losses.append(A_cycle_loss.data[0])
        B_cycle_losses.append(B_cycle_loss.data[0])
        # train discriminator D_A (real B vs. fake B drawn from the history buffer)
        D_A_optimizer.zero_grad()

        D_A_real = D_A(realB)
        D_A_real_loss = MSE_loss(D_A_real, Variable(torch.ones(D_A_real.size()).cuda()))

        # fakeB = fakeB_store.query(fakeB.data)
        fakeB = fakeB_store.query(fakeB)
        D_A_fake = D_A(fakeB)
        D_A_fake_loss = MSE_loss(D_A_fake, Variable(torch.zeros(D_A_fake.size()).cuda()))

        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_loss.backward()
        D_A_optimizer.step()

        train_hist['D_A_losses'].append(D_A_loss.data[0])
        D_A_losses.append(D_A_loss.data[0])

        # train discriminator D_B (real A vs. fake A drawn from the history buffer)
        D_B_optimizer.zero_grad()

        D_B_real = D_B(realA)
        D_B_real_loss = MSE_loss(D_B_real, Variable(torch.ones(D_B_real.size()).cuda()))

        # fakeA = fakeA_store.query(fakeA.data)
        fakeA = fakeA_store.query(fakeA)
        D_B_fake = D_B(fakeA)
        D_B_fake_loss = MSE_loss(D_B_fake, Variable(torch.zeros(D_B_fake.size()).cuda()))

        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_loss.backward()
        D_B_optimizer.step()

        train_hist['D_B_losses'].append(D_B_loss.data[0])
        D_B_losses.append(D_B_loss.data[0])

        num_iter += 1

    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
    print(
        '[%d/%d] - ptime: %.2f, loss_D_A: %.3f, loss_D_B: %.3f, loss_G_A: %.3f, loss_G_B: %.3f, loss_A_cycle: %.3f, loss_B_cycle: %.3f' % (
        (epoch + 1), opt.train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_A_losses)),
        torch.mean(torch.FloatTensor(D_B_losses)), torch.mean(torch.FloatTensor(G_A_losses)),
        torch.mean(torch.FloatTensor(G_B_losses)), torch.mean(torch.FloatTensor(A_cycle_losses)),
        torch.mean(torch.FloatTensor(B_cycle_losses))))

    if (epoch+1) % 10 == 0:
        # every tenth epoch, dump translations of the full test sets
        # test A to B
        n = 0
        for realA, _ in test_loader_A:
            n += 1
            path = root + 'test_results/AtoB/' + str(n) + '_input.png'  # was opt.dataset + '_results/...', which ignored --save_root
            plt.imsave(path, (realA[0].numpy().transpose(1, 2, 0) + 1) / 2)
            realA = Variable(realA.cuda(), volatile=True)
            genB = G_A(realA)
            path = root + 'test_results/AtoB/' + str(n) + '_output.png'
            plt.imsave(path, (genB[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
            recA = G_B(genB)
            path = root + 'test_results/AtoB/' + str(n) + '_recon.png'
            plt.imsave(path, (recA[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)

        # test B to A
        n = 0
        for realB, _ in test_loader_B:
            n += 1
            path = root + 'test_results/BtoA/' + str(n) + '_input.png'
            plt.imsave(path, (realB[0].numpy().transpose(1, 2, 0) + 1) / 2)
            realB = Variable(realB.cuda(), volatile=True)
            genA = G_B(realB)
            path = root + 'test_results/BtoA/' + str(n) + '_output.png'
            plt.imsave(path, (genA[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
            recB = G_A(genA)
            path = root + 'test_results/BtoA/' + str(n) + '_recon.png'
            plt.imsave(path, (recB[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
    else:
        # on other epochs, dump ten translated training samples per direction
        # train A to B
        n = 0
        for realA, _ in train_loader_A:
            n += 1
            path = root + 'train_results/AtoB/' + str(n) + '_input.png'
            plt.imsave(path, (realA[0].numpy().transpose(1, 2, 0) + 1) / 2)
            realA = Variable(realA.cuda(), volatile=True)
            genB = G_A(realA)
            path = root + 'train_results/AtoB/' + str(n) + '_output.png'
            plt.imsave(path, (genB[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
            recA = G_B(genB)
            path = root + 'train_results/AtoB/' + str(n) + '_recon.png'
            plt.imsave(path, (recA[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
            if n > 9:
                break

        # train B to A
        n = 0
        for realB, _ in train_loader_B:
            n += 1
            path = root + 'train_results/BtoA/' + str(n) + '_input.png'
            plt.imsave(path, (realB[0].numpy().transpose(1, 2, 0) + 1) / 2)
            realB = Variable(realB.cuda(), volatile=True)
            genA = G_B(realB)
            path = root + 'train_results/BtoA/' + str(n) + '_output.png'
            plt.imsave(path, (genA[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
            recB = G_A(genA)
            path = root + 'train_results/BtoA/' + str(n) + '_recon.png'
            plt.imsave(path, (recB[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
            if n > 9:
                break

end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)

print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), opt.train_epoch, total_ptime))
print("Training finish!... save training results")
torch.save(G_A.state_dict(), root + model + 'generatorA_param.pkl')
torch.save(G_B.state_dict(), root + model + 'generatorB_param.pkl')
torch.save(D_A.state_dict(), root + model + 'discriminatorA_param.pkl')
torch.save(D_B.state_dict(), root + model + 'discriminatorB_param.pkl')
with open(root + model + 'train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

util.show_train_hist(train_hist, save=True, path=root + model + 'train_hist.png')
--------------------------------------------------------------------------------
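For reference, the training loop in `pytorch_cycleGAN.py` optimizes the CycleGAN objective of [1] with the least-squares GAN loss. A transcription of what the code computes per batch (`lambdaA = lambdaB = 10` by default; in the discriminator steps the fake is drawn from the 50-image history buffer):

```latex
% Generator step (G_A : A \to B, \; G_B : B \to A):
L_G = (D_A(G_A(a)) - 1)^2 + (D_B(G_B(b)) - 1)^2
    + \lambda_A \, \lVert G_B(G_A(a)) - a \rVert_1
    + \lambda_B \, \lVert G_A(G_B(b)) - b \rVert_1

% Discriminator steps (halved, as in the paper):
L_{D_A} = \tfrac{1}{2} \big[ (D_A(b) - 1)^2 + D_A(G_A(a))^2 \big]
L_{D_B} = \tfrac{1}{2} \big[ (D_B(a) - 1)^2 + D_B(G_B(b))^2 \big]
```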