├── images
│ ├── apple2orange
│ │ ├── AtoB
│ │ │ ├── 1_input.png
│ │ │ ├── 1_output.png
│ │ │ ├── 1_recon.png
│ │ │ ├── 2_input.png
│ │ │ ├── 2_output.png
│ │ │ ├── 2_recon.png
│ │ │ ├── 3_input.png
│ │ │ ├── 3_output.png
│ │ │ ├── 3_recon.png
│ │ │ ├── 4_input.png
│ │ │ ├── 4_output.png
│ │ │ ├── 4_recon.png
│ │ │ ├── 5_input.png
│ │ │ ├── 5_output.png
│ │ │ └── 5_recon.png
│ │ └── BtoA
│ │   ├── 1_input.png
│ │   ├── 1_output.png
│ │   ├── 1_recon.png
│ │   ├── 2_input.png
│ │   ├── 2_output.png
│ │   ├── 2_recon.png
│ │   ├── 3_input.png
│ │   ├── 3_output.png
│ │   ├── 3_recon.png
│ │   ├── 4_input.png
│ │   ├── 4_output.png
│ │   ├── 4_recon.png
│ │   ├── 5_input.png
│ │   ├── 5_output.png
│ │   └── 5_recon.png
│ └── horse2zebra
│   ├── AtoB
│   │ ├── 1_input.png
│   │ ├── 1_output.png
│   │ ├── 1_recon.png
│   │ ├── 2_input.png
│   │ ├── 2_output.png
│   │ ├── 2_recon.png
│   │ ├── 3_input.png
│   │ ├── 3_output.png
│   │ ├── 3_recon.png
│   │ ├── 4_input.png
│   │ ├── 4_output.png
│   │ ├── 4_recon.png
│   │ ├── 5_input.png
│   │ ├── 5_output.png
│   │ └── 5_recon.png
│   └── BtoA
│     ├── 1_input.png
│     ├── 1_output.png
│     ├── 1_recon.png
│     ├── 2_input.png
│     ├── 2_output.png
│     ├── 2_recon.png
│     ├── 3_input.png
│     ├── 3_output.png
│     ├── 3_recon.png
│     ├── 4_input.png
│     ├── 4_output.png
│     ├── 4_recon.png
│     ├── 5_input.png
│     ├── 5_output.png
│     └── 5_recon.png
├── network.py
├── README.md
├── util.py
└── pytorch_cycleGAN.py
/images/apple2orange/AtoB/1_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/1_input.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/1_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/1_output.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/1_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/1_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/2_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/2_input.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/2_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/2_output.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/2_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/2_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/3_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/3_input.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/3_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/3_output.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/3_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/3_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/4_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/4_input.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/4_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/4_output.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/4_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/4_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/5_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/5_input.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/5_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/5_output.png
--------------------------------------------------------------------------------
/images/apple2orange/AtoB/5_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/AtoB/5_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/1_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/1_input.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/1_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/1_output.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/1_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/1_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/2_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/2_input.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/2_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/2_output.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/2_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/2_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/3_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/3_input.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/3_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/3_output.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/3_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/3_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/4_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/4_input.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/4_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/4_output.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/4_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/4_recon.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/5_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/5_input.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/5_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/5_output.png
--------------------------------------------------------------------------------
/images/apple2orange/BtoA/5_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/apple2orange/BtoA/5_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/1_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/1_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/1_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/1_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/1_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/1_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/2_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/2_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/2_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/2_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/2_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/2_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/3_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/3_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/3_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/3_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/3_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/3_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/4_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/4_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/4_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/4_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/4_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/4_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/5_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/5_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/5_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/5_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/AtoB/5_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/AtoB/5_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/1_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/1_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/1_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/1_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/1_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/1_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/2_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/2_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/2_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/2_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/2_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/2_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/3_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/3_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/3_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/3_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/3_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/3_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/4_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/4_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/4_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/4_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/4_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/4_recon.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/5_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/5_input.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/5_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/5_output.png
--------------------------------------------------------------------------------
/images/horse2zebra/BtoA/5_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-CycleGAN/HEAD/images/horse2zebra/BtoA/5_recon.png
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Layout: reflect-padded 7x7 conv -> two stride-2 3x3 convs -> ``nb``
    residual blocks -> two stride-2 3x3 transposed convs -> reflect-padded
    7x7 conv -> tanh. Every layer except the final 7x7 conv is followed by
    InstanceNorm and ReLU.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        ngf: filters in the first conv; doubled by each downsampling conv.
        nb: number of residual blocks in the bottleneck.
    """

    # initializers
    def __init__(self, input_nc, output_nc, ngf=32, nb=6):
        super(generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.nb = nb
        # padding=0 here: forward() applies 3-pixel reflect padding explicitly.
        self.conv1 = nn.Conv2d(input_nc, ngf, 7, 1, 0)
        self.conv1_norm = nn.InstanceNorm2d(ngf)
        self.conv2 = nn.Conv2d(ngf, ngf * 2, 3, 2, 1)
        self.conv2_norm = nn.InstanceNorm2d(ngf * 2)
        self.conv3 = nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1)
        self.conv3_norm = nn.InstanceNorm2d(ngf * 4)

        # Residual blocks are initialized as they are built, so the model is
        # usable even if weight_init() is never called.
        blocks = []
        for _ in range(nb):
            block = resnet_block(ngf * 4, 3, 1, 1)
            block.weight_init(0, 0.02)
            blocks.append(block)
        self.resnet_blocks = nn.Sequential(*blocks)

        # kernel 3, stride 2, padding 1, output_padding 1: each transposed
        # conv outputs exactly twice the input height/width.
        self.deconv1 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1)
        self.deconv1_norm = nn.InstanceNorm2d(ngf * 2)
        self.deconv2 = nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1)
        self.deconv2_norm = nn.InstanceNorm2d(ngf)
        self.deconv3 = nn.Conv2d(ngf, output_nc, 7, 1, 0)

    # weight_init
    def weight_init(self, mean, std):
        """Re-initialize every (transposed) conv in the model, including those
        inside the residual blocks: N(mean, std) weights, zero bias."""
        # self.modules() recurses into nn.Sequential / resnet_block; iterating
        # only the direct children skipped the residual-block convs entirely.
        for m in self.modules():
            normal_init(m, mean, std)

    # forward method
    def forward(self, input):
        x = F.pad(input, (3, 3, 3, 3), 'reflect')
        x = F.relu(self.conv1_norm(self.conv1(x)))
        x = F.relu(self.conv2_norm(self.conv2(x)))
        x = F.relu(self.conv3_norm(self.conv3(x)))
        x = self.resnet_blocks(x)
        x = F.relu(self.deconv1_norm(self.deconv1(x)))
        x = F.relu(self.deconv2_norm(self.deconv2(x)))
        x = F.pad(x, (3, 3, 3, 3), 'reflect')
        # torch.tanh replaces the deprecated F.tanh.
        o = torch.tanh(self.deconv3(x))

        return o
51 |
class discriminator(nn.Module):
    """Fully convolutional discriminator producing a spatial map of raw scores.

    Five 4x4 convs: the first three use stride 2, the last two stride 1.
    conv2-conv4 are followed by InstanceNorm, and conv1-conv4 by
    LeakyReLU(0.2). The conv5 output is returned without any activation.

    Args:
        input_nc: channels of the image being judged.
        output_nc: channels of the returned score map.
        ndf: filters in the first conv; doubled by each of the next three.
    """

    def __init__(self, input_nc, output_nc, ndf=64):
        super(discriminator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ndf = ndf
        widths = (ndf, ndf * 2, ndf * 4, ndf * 8)
        # Registration order matters: it fixes parameter order for
        # optimizers and state_dict, so it mirrors the layer order.
        self.conv1 = nn.Conv2d(input_nc, widths[0], 4, 2, 1)
        self.conv2 = nn.Conv2d(widths[0], widths[1], 4, 2, 1)
        self.conv2_norm = nn.InstanceNorm2d(widths[1])
        self.conv3 = nn.Conv2d(widths[1], widths[2], 4, 2, 1)
        self.conv3_norm = nn.InstanceNorm2d(widths[2])
        self.conv4 = nn.Conv2d(widths[2], widths[3], 4, 1, 1)
        self.conv4_norm = nn.InstanceNorm2d(widths[3])
        self.conv5 = nn.Conv2d(widths[3], output_nc, 4, 1, 1)

    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to each direct child layer."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    def forward(self, input):
        h = F.leaky_relu(self.conv1(input), 0.2)
        normed_stages = (
            (self.conv2, self.conv2_norm),
            (self.conv3, self.conv3_norm),
            (self.conv4, self.conv4_norm),
        )
        for conv, norm in normed_stages:
            h = F.leaky_relu(norm(conv(h)), 0.2)
        return self.conv5(h)
82 |
# resnet block with reflect padding
class resnet_block(nn.Module):
    """Residual block: two reflect-padded convs with InstanceNorm (ReLU after
    the first only), added to the block input via an identity skip.

    Args:
        channel: number of input and output channels.
        kernel: convolution kernel size.
        stride: convolution stride. Because the input is added to the conv
            output, shapes only match when stride == 1 and
            2 * padding == kernel - 1 (e.g. kernel=3, padding=1).
        padding: reflect-padding width applied before each conv.
    """

    def __init__(self, channel, kernel, stride, padding):
        super(resnet_block, self).__init__()
        self.channel = channel
        self.kernel = kernel
        self.stride = stride
        # Legacy misspelled attribute, kept so code or pickles that read it still work.
        self.strdie = stride
        self.padding = padding
        # padding=0: forward() reflect-pads explicitly before each conv.
        self.conv1 = nn.Conv2d(channel, channel, kernel, stride, 0)
        self.conv1_norm = nn.InstanceNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, kernel, stride, 0)
        self.conv2_norm = nn.InstanceNorm2d(channel)

    # weight_init
    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to each direct child layer."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward(self, input):
        pad = (self.padding, self.padding, self.padding, self.padding)
        x = F.pad(input, pad, 'reflect')
        x = F.relu(self.conv1_norm(self.conv1(x)))
        x = F.pad(x, pad, 'reflect')
        x = self.conv2_norm(self.conv2(x))

        return input + x
108 |
def normal_init(m, mean, std):
    """Initialize a Conv2d / ConvTranspose2d layer in place.

    Weights are drawn from N(mean, std). The bias, if the layer has one, is
    set to zero. Any other module type is left untouched.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Layers built with bias=False have m.bias == None.
        if m.bias is not None:
            m.bias.data.zero_()
113 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-CycleGAN
2 | PyTorch implementation of CycleGAN [1].
3 |
4 | * You can download the datasets from: https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
5 | * See https://arxiv.org/pdf/1703.10593.pdf for more details on the network architecture and training.
6 |
7 | ## dataset
8 | * apple2orange
9 | * apple training images: 995, orange training images: 1,019, apple test images: 266, orange test images: 248
10 | * horse2zebra
11 | * horse training images: 1,067, zebra training images: 1,334, horse test images: 120, zebra test images: 140
12 |
13 | ## Results
14 | ### apple2orange (after 200 epochs)
15 | * apple2orange
16 |
17 |
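| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/apple2orange/AtoB/1_input.png) | ![](images/apple2orange/AtoB/1_output.png) | ![](images/apple2orange/AtoB/1_recon.png) |
| ![](images/apple2orange/AtoB/2_input.png) | ![](images/apple2orange/AtoB/2_output.png) | ![](images/apple2orange/AtoB/2_recon.png) |
| ![](images/apple2orange/AtoB/3_input.png) | ![](images/apple2orange/AtoB/3_output.png) | ![](images/apple2orange/AtoB/3_recon.png) |
| ![](images/apple2orange/AtoB/4_input.png) | ![](images/apple2orange/AtoB/4_output.png) | ![](images/apple2orange/AtoB/4_recon.png) |
| ![](images/apple2orange/AtoB/5_input.png) | ![](images/apple2orange/AtoB/5_output.png) | ![](images/apple2orange/AtoB/5_recon.png) |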
47 |
48 |
49 | * orange2apple
50 |
51 |
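| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/apple2orange/BtoA/1_input.png) | ![](images/apple2orange/BtoA/1_output.png) | ![](images/apple2orange/BtoA/1_recon.png) |
| ![](images/apple2orange/BtoA/2_input.png) | ![](images/apple2orange/BtoA/2_output.png) | ![](images/apple2orange/BtoA/2_recon.png) |
| ![](images/apple2orange/BtoA/3_input.png) | ![](images/apple2orange/BtoA/3_output.png) | ![](images/apple2orange/BtoA/3_recon.png) |
| ![](images/apple2orange/BtoA/4_input.png) | ![](images/apple2orange/BtoA/4_output.png) | ![](images/apple2orange/BtoA/4_recon.png) |
| ![](images/apple2orange/BtoA/5_input.png) | ![](images/apple2orange/BtoA/5_output.png) | ![](images/apple2orange/BtoA/5_recon.png) |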
81 |
82 |
83 | * Learning Time
84 | * apple2orange - Avg. per epoch: 299.38 sec; Total 200 epochs: 62,225.33 sec
85 |
86 | ### horse2zebra (after 200 epochs)
87 | * horse2zebra
88 |
89 |
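| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/horse2zebra/AtoB/1_input.png) | ![](images/horse2zebra/AtoB/1_output.png) | ![](images/horse2zebra/AtoB/1_recon.png) |
| ![](images/horse2zebra/AtoB/2_input.png) | ![](images/horse2zebra/AtoB/2_output.png) | ![](images/horse2zebra/AtoB/2_recon.png) |
| ![](images/horse2zebra/AtoB/3_input.png) | ![](images/horse2zebra/AtoB/3_output.png) | ![](images/horse2zebra/AtoB/3_recon.png) |
| ![](images/horse2zebra/AtoB/4_input.png) | ![](images/horse2zebra/AtoB/4_output.png) | ![](images/horse2zebra/AtoB/4_recon.png) |
| ![](images/horse2zebra/AtoB/5_input.png) | ![](images/horse2zebra/AtoB/5_output.png) | ![](images/horse2zebra/AtoB/5_recon.png) |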
119 |
120 |
121 | * zebra2horse
122 |
123 |
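| Input | Output | Reconstruction |
|:---:|:---:|:---:|
| ![](images/horse2zebra/BtoA/1_input.png) | ![](images/horse2zebra/BtoA/1_output.png) | ![](images/horse2zebra/BtoA/1_recon.png) |
| ![](images/horse2zebra/BtoA/2_input.png) | ![](images/horse2zebra/BtoA/2_output.png) | ![](images/horse2zebra/BtoA/2_recon.png) |
| ![](images/horse2zebra/BtoA/3_input.png) | ![](images/horse2zebra/BtoA/3_output.png) | ![](images/horse2zebra/BtoA/3_recon.png) |
| ![](images/horse2zebra/BtoA/4_input.png) | ![](images/horse2zebra/BtoA/4_output.png) | ![](images/horse2zebra/BtoA/4_recon.png) |
| ![](images/horse2zebra/BtoA/5_input.png) | ![](images/horse2zebra/BtoA/5_output.png) | ![](images/horse2zebra/BtoA/5_recon.png) |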
153 |
154 |
155 | * Learning Time
156 | * horse2zebra - Avg. per epoch: 299.25 sec; Total 200 epochs: 61,221.27 sec
157 |
158 | ## Development Environment
159 |
160 | * Ubuntu 14.04 LTS
161 | * NVIDIA GTX 1080 ti
162 | * cuda 8.0
163 | * Python 2.7.6
164 | * pytorch 0.1.12
165 | * matplotlib 1.3.1
166 | * scipy 0.19.1
167 |
168 | ## Reference
169 |
170 | [1] Zhu, Jun-Yan, et al. "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv preprint arXiv:1703.10593 (2017).
171 |
172 | (Full paper: https://arxiv.org/pdf/1703.10593.pdf)
173 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import itertools, imageio, torch, random
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | from torchvision import datasets
5 | from scipy.misc import imresize
6 | from torch.autograd import Variable
7 |
def show_result(G, x_, y_, num_epoch, show=False, save=False, path='result.png'):
    """Plot one row per sample: input, G(input), and the matching ``y_`` image.

    Args:
        G: generator, called once on ``x_``.
        x_: input batch; indexed as (N, C, H, W) and moved to CPU for display.
            Values are shown as (v + 1) / 2, presumably mapping [-1, 1] to
            [0, 1] — confirm against the data pipeline.
        y_: batch shown in the third column; ``.numpy()`` is called on it
            directly, so it must already be a CPU tensor.
        num_epoch: epoch number printed under the grid.
        show: if True, display the figure; otherwise it is closed.
        save: if True, write the figure to ``path``.
        path: output file used when ``save`` is True.
    """
    test_images = G(x_)

    size_figure_grid = 3
    num_rows = x_.size()[0]
    # squeeze=False keeps ``ax`` 2-D even for a single row. Without it,
    # plt.subplots(1, 3) returns a 1-D array and ax[i, j] raises IndexError
    # (batch size 1 is the training default).
    fig, ax = plt.subplots(num_rows, size_figure_grid, figsize=(5, 5), squeeze=False)
    for i, j in itertools.product(range(num_rows), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    for i in range(num_rows):
        ax[i, 0].cla()
        ax[i, 0].imshow((x_[i].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
        ax[i, 1].cla()
        ax[i, 1].imshow((test_images[i].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
        ax[i, 2].cla()
        ax[i, 2].imshow((y_[i].numpy().transpose(1, 2, 0) + 1) / 2)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()
35 |
def show_train_hist(hist, show=False, save=False, path='Train_hist.png'):
    """Plot the six recorded loss curves against their iteration index.

    Args:
        hist: mapping with sequences under 'D_A_losses', 'D_B_losses',
            'G_A_losses', 'G_B_losses', 'A_cycle_losses' and 'B_cycle_losses'.
            The x axis length is taken from 'D_A_losses'.
        show: if True, display the figure; otherwise it is closed.
        save: if True, write the figure to ``path``.
        path: output file used when ``save`` is True.
    """
    curve_specs = (
        ('D_A_losses', 'D_A_loss'),
        ('D_B_losses', 'D_B_loss'),
        ('G_A_losses', 'G_A_loss'),
        ('G_B_losses', 'G_B_loss'),
        ('A_cycle_losses', 'A_cycle_loss'),
        ('B_cycle_losses', 'B_cycle_loss'),
    )
    iterations = range(len(hist['D_A_losses']))
    # Every series is looked up before anything is drawn, so a missing key
    # fails before the current figure is touched.
    curves = [(hist[key], label) for key, label in curve_specs]

    for values, label in curves:
        plt.plot(iterations, values, label=label)

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()
68 |
def generate_animation(root, model, opt):
    """Stitch per-epoch result images into a 5 fps GIF.

    Frames are read from root + 'Fixed_results/' + model + '<epoch>.png' for
    epochs 1..opt.train_epoch. The GIF is written to
    root + model + 'generate_animation.gif'. Paths are built by plain string
    concatenation, so ``root`` must already end with a path separator.
    """
    frame_paths = [root + 'Fixed_results/' + model + str(epoch) + '.png'
                   for epoch in range(1, opt.train_epoch + 1)]
    frames = [imageio.imread(frame_path) for frame_path in frame_paths]
    imageio.mimsave(root + model + 'generate_animation.gif', frames, fps=5)
75 |
def data_load(path, subfolder, transform, batch_size, shuffle=False):
    """Build a DataLoader over the images of one ImageFolder class.

    ``path`` is loaded with ``datasets.ImageFolder``; only samples whose label
    equals ``class_to_idx[subfolder]`` are kept.

    Args:
        path: root directory passed to ``datasets.ImageFolder``.
        subfolder: class (sub-directory) name to keep, e.g. 'train'.
        transform: transform handed to ``datasets.ImageFolder``.
        batch_size: DataLoader batch size.
        shuffle: DataLoader ``shuffle`` flag.

    Returns:
        A ``torch.utils.data.DataLoader`` over the filtered dataset.

    Raises:
        KeyError: if ``subfolder`` is not a key of ``dset.class_to_idx``.
    """
    dset = datasets.ImageFolder(path, transform)
    ind = dset.class_to_idx[subfolder]

    # One linear pass instead of repeated list.del (O(n^2)). Slice assignment
    # mutates the existing list object in place.
    # NOTE(review): newer torchvision appears to alias dset.imgs and
    # dset.samples to the same list, so both stay consistent — confirm per
    # version.
    dset.imgs[:] = [item for item in dset.imgs if item[1] == ind]
    # Newer torchvision also keeps a parallel ``targets`` list; keep it in sync.
    if hasattr(dset, 'targets'):
        dset.targets = [item[1] for item in dset.imgs]

    return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle)
89 |
def imgs_resize(imgs, resize_scale=286):
    """Resize every image of a batch to ``resize_scale`` x ``resize_scale``.

    ``imgs`` is indexed as (N, C, H, W) and must be a CPU tensor
    (``.numpy()`` is called on each image). Each image goes through
    ``imresize``. The result is mapped with (v - 127.5) / 127.5, i.e.
    [0, 255] -> [-1, 1].

    NOTE(review): this relies on scipy.misc.imresize, which is believed to
    return a uint8 (H, W, C) array, to rescale float input by the array's own
    min/max, and to have been removed from SciPy 1.3+ — confirm before
    upgrading SciPy.

    Returns:
        New float tensor of shape (N, C, resize_scale, resize_scale).
    """
    batch = imgs.size()[0]
    channels = imgs.size()[1]
    resized = torch.FloatTensor(batch, channels, resize_scale, resize_scale)
    for idx in range(batch):
        pixels = imresize(imgs[idx].numpy(), [resize_scale, resize_scale])
        chw = pixels.transpose(2, 0, 1).astype(np.float32)
        chw = chw.reshape(-1, channels, resize_scale, resize_scale)
        resized[idx] = torch.FloatTensor((chw - 127.5) / 127.5)

    return resized
97 |
def random_crop(imgs, crop_size=256):
    """Randomly crop every image of a batch to ``crop_size`` x ``crop_size``.

    Each image gets its own independently drawn offset.

    Args:
        imgs: batch indexed as (N, C, H, W); H and W must both be >= crop_size.
        crop_size: side length of the square crop.

    Returns:
        New float tensor of shape (N, C, crop_size, crop_size).
    """
    size = imgs.size()
    height, width = size[2], size[3]
    outputs = torch.FloatTensor(size[0], size[1], crop_size, crop_size)
    for i in range(size[0]):
        # numpy's randint excludes ``high``. The + 1 makes the last valid
        # offset reachable and allows crop_size == height/width instead of
        # raising ValueError.
        top = np.random.randint(0, height - crop_size + 1)
        # The column offset must be bounded by the width, not the height.
        left = np.random.randint(0, width - crop_size + 1)
        outputs[i] = imgs[i][:, top: top + crop_size, left: left + crop_size]

    return outputs
107 |
def random_fliplr(imgs):
    """Mirror each image of a batch horizontally with probability 1/2.

    Args:
        imgs: CPU batch indexed as (N, C, H, W).

    Returns:
        New float tensor of the same size. Each image is either copied
        unchanged or reversed along its last (width) axis. One torch.rand(1)
        draw is made per image.
    """
    outputs = torch.FloatTensor(imgs.size())
    for i in range(imgs.size()[0]):
        if torch.rand(1)[0] < 0.5:
            # Reverse the width axis directly. The former (x+1)/2 -> (y-0.5)/0.5
            # round trip was a value no-op apart from float rounding. .copy()
            # yields a contiguous array because torch cannot wrap
            # negative-stride numpy views.
            mirrored = imgs[i].numpy()[:, :, ::-1].copy()
            outputs[i] = torch.from_numpy(mirrored)
        else:
            outputs[i] = imgs[i]

    return outputs
119 |
def print_network(net):
    """Print ``net`` followed by the total element count of its parameters."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
126 |
class image_store():
    """Fixed-size history buffer of previously seen samples.

    The first ``store_size`` samples queried are stored and returned as-is.
    After that, each incoming sample is, when np.random.uniform(0, 1) > 0.5,
    swapped with a randomly chosen stored sample (the stored one is returned).
    Otherwise it is returned unchanged.
    """

    def __init__(self, store_size=50):
        self.store_size = store_size
        self.num_img = 0
        self.images = []

    def query(self, image):
        """Return a batch the same size as ``image`` mixing new and stored samples.

        Args:
            image: batch indexed along dimension 0. Stored entries are not
                detached — NOTE(review): confirm callers pass tensors that are
                safe to keep across iterations.

        Returns:
            Variable wrapping the concatenation of the selected samples.
        """
        select_imgs = []
        for i in range(image.size()[0]):
            # Handle one sample at a time; slicing keeps a leading dimension
            # of size 1 so the results concatenate back into a batch.
            # Appending the whole batch here would return B * B images.
            sample = image[i:i + 1]
            if self.num_img < self.store_size:
                self.images.append(sample)
                select_imgs.append(sample)
                self.num_img += 1
            else:
                prob = np.random.uniform(0, 1)
                if prob > 0.5:
                    # randint excludes ``high``; store_size makes every slot eligible.
                    ind = np.random.randint(0, self.store_size)
                    select_imgs.append(self.images[ind])
                    self.images[ind] = sample
                else:
                    select_imgs.append(sample)

        return Variable(torch.cat(select_imgs, 0))
150 |
class ImagePool():
    """History pool of previously generated images.

    With ``pool_size == 0`` pooling is disabled and ``query`` returns its
    input unchanged. With a positive size, the first ``pool_size`` samples
    fill the pool and are returned as-is. Later samples are, when
    random.uniform(0, 1) > 0.5, swapped with a random pool entry (a clone of
    the old entry is returned), and are otherwise returned unchanged.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        # State only exists for a positive size; query() returns before
        # touching it when pool_size == 0.
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch the same size as ``images`` drawn from input and pool.

        Samples are taken from ``images.data``, so the pool holds tensors that
        are detached from the autograd graph.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = torch.unsqueeze(sample, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.pool_size - 1)
                previous = self.images[slot].clone()
                self.images[slot] = sample
                picked.append(previous)
            else:
                picked.append(sample)
        return Variable(torch.cat(picked, 0))
--------------------------------------------------------------------------------
/pytorch_cycleGAN.py:
--------------------------------------------------------------------------------
1 | import os, time, pickle, argparse, network, util, itertools
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import matplotlib.pyplot as plt
6 | from torchvision import transforms
7 | from torch.autograd import Variable
8 |
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', 'yes' or '0'.

    ``type=bool`` cannot be used for flags: ``bool('False')`` is ``True``, so
    every non-empty string would switch the option on.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, default='apple2orange', help='')
parser.add_argument('--train_subfolder', required=False, default='train', help='')
parser.add_argument('--test_subfolder', required=False, default='test', help='')
parser.add_argument('--input_ngc', type=int, default=3, help='input channel for generator')
parser.add_argument('--output_ngc', type=int, default=3, help='output channel for generator')
parser.add_argument('--input_ndc', type=int, default=3, help='input channel for discriminator')
parser.add_argument('--output_ndc', type=int, default=1, help='output channel for discriminator')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--ngf', type=int, default=32)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nb', type=int, default=9, help='the number of resnet block layer for generator')
parser.add_argument('--input_size', type=int, default=256, help='input size')
parser.add_argument('--resize_scale', type=int, default=286, help='resize scale (0 is false)')
parser.add_argument('--crop', type=_str2bool, default=True, help='random crop True or False')
parser.add_argument('--fliplr', type=_str2bool, default=True, help='random fliplr True or False')
parser.add_argument('--train_epoch', type=int, default=200, help='train epochs num')
parser.add_argument('--decay_epoch', type=int, default=100, help='learning rate decay start epoch num')
parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--lambdaA', type=float, default=10, help='lambdaA for cycle loss')
parser.add_argument('--lambdaB', type=float, default=10, help='lambdaB for cycle loss')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--save_root', required=False, default='results', help='results save path')
opt = parser.parse_args()

# Echo every option, sorted by name, so the run configuration is in the log.
print('------------ Options -------------')
for k, v in sorted(vars(opt).items()):
    print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
39 |
# results save path
root = opt.dataset + '_' + opt.save_root + '/'
model = opt.dataset + '_'
# Create every directory the script writes into. The train_results folders
# receive the per-epoch training samples, so they must exist before the
# first epoch ends. makedirs also creates missing parents when save_root is a
# nested path (os.makedirs has no exist_ok on Python 2, hence the isdir check).
for save_dir in ('', 'test_results/AtoB', 'test_results/BtoA',
                 'train_results/AtoB', 'train_results/BtoA'):
    if not os.path.isdir(root + save_dir):
        os.makedirs(root + save_dir)
51 |
# data_loader: ToTensor, then Normalize maps each channel through
# (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
data_dir = 'data/' + opt.dataset
# One loader per split and domain; domain folders are '<subfolder>A' and
# '<subfolder>B'. Only the training loaders shuffle.
train_loader_A, train_loader_B = [
    util.data_load(data_dir, opt.train_subfolder + domain, transform, opt.batch_size, shuffle=True)
    for domain in ('A', 'B')
]
test_loader_A, test_loader_B = [
    util.data_load(data_dir, opt.test_subfolder + domain, transform, opt.batch_size, shuffle=False)
    for domain in ('A', 'B')
]
61 |
# network: G_A translates A -> B and G_B translates B -> A; D_A scores
# domain-B images and D_B scores domain-A images.
G_A = network.generator(opt.input_ngc, opt.output_ngc, opt.ngf, opt.nb)
G_B = network.generator(opt.input_ngc, opt.output_ngc, opt.ngf, opt.nb)
D_A = network.discriminator(opt.input_ndc, opt.output_ndc, opt.ndf)
D_B = network.discriminator(opt.input_ndc, opt.output_ndc, opt.ndf)
all_nets = (G_A, G_B, D_A, D_B)

# Separate passes keep the original order: initialise all, then move all to
# the GPU, then switch all to training mode.
for net in all_nets:
    net.weight_init(mean=0.0, std=0.02)
for net in all_nets:
    net.cuda()
for net in all_nets:
    net.train()

print('---------- Networks initialized -------------')
for net in all_nets:
    util.print_network(net)
print('-----------------------------------------------')
85 |
# loss: MSE against all-ones / all-zeros targets is the (least-squares)
# adversarial loss and L1 is the cycle-consistency loss. BCE_loss is created
# but not used by the training loop in this script.
BCE_loss, MSE_loss, L1_loss = nn.BCELoss().cuda(), nn.MSELoss().cuda(), nn.L1Loss().cuda()

# Adam optimizer: both generators share one optimizer because they are
# updated from the same combined G_loss; each discriminator has its own.
adam_betas = (opt.beta1, opt.beta2)
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=opt.lrG, betas=adam_betas)
D_A_optimizer = optim.Adam(D_A.parameters(), lr=opt.lrD, betas=adam_betas)
D_B_optimizer = optim.Adam(D_B.parameters(), lr=opt.lrD, betas=adam_betas)
95 |
# image store: one 50-image history pool per generated domain; the training
# loop draws the discriminators' fake batches from these.
fakeA_store = util.ImagePool(50)
fakeB_store = util.ImagePool(50)

# Per-iteration loss histories plus per-epoch and total timings; pickled and
# plotted once training finishes.
train_hist = {}
for hist_key in ('D_A_losses', 'D_B_losses', 'G_A_losses', 'G_B_losses',
                 'A_cycle_losses', 'B_cycle_losses', 'per_epoch_ptimes', 'total_ptime'):
    train_hist[hist_key] = []
111 |
def _loss_value(loss):
    """Return the scalar value of a loss Variable as a Python float.

    ``loss.data[0]`` only works while losses are 1-element tensors (PyTorch
    0.3); flattening first also works for the 0-dim losses of later releases.
    """
    return float(loss.data.view(-1)[0])


def _record(hist_key, epoch_list, loss):
    """Append the value of ``loss`` to ``train_hist[hist_key]`` and ``epoch_list``."""
    value = _loss_value(loss)
    train_hist[hist_key].append(value)
    epoch_list.append(value)


def _save_samples(loader, G_fwd, G_bwd, out_dir, max_samples=None):
    """Translate batches from ``loader`` and save input/output/reconstruction.

    For the n-th batch (1-based), the first image of the batch is written to
    ``<out_dir><n>_input.png``, its translation ``G_fwd(input)`` to
    ``<n>_output.png`` and the round trip ``G_bwd(G_fwd(input))`` to
    ``<n>_recon.png``. Pixels are rescaled with ``(x + 1) / 2`` before saving,
    which assumes images in [-1, 1].

    Args:
        loader: iterable of 2-tuples whose first element is an image batch;
            the second element is ignored.
        G_fwd: generator applied to the loader's images.
        G_bwd: generator that maps the translation back.
        out_dir: output directory ending in '/'; created if missing.
        max_samples: stop after this many batches; ``None`` saves them all.
    """
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for n, (real, _) in enumerate(loader, 1):
        plt.imsave(out_dir + str(n) + '_input.png', (real[0].numpy().transpose(1, 2, 0) + 1) / 2)
        real = Variable(real.cuda(), volatile=True)
        fake = G_fwd(real)
        plt.imsave(out_dir + str(n) + '_output.png', (fake[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
        recon = G_bwd(fake)
        plt.imsave(out_dir + str(n) + '_recon.png', (recon[0].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)
        if max_samples is not None and n >= max_samples:
            break


# itertools.izip only exists on Python 2 (where zip would load both loaders
# into memory up front); on Python 3 the built-in zip is already lazy.
_lazy_zip = getattr(itertools, 'izip', zip)

print('training start!')
start_time = time.time()
for epoch in range(opt.train_epoch):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    A_cycle_losses = []
    B_cycle_losses = []
    epoch_start_time = time.time()
    if (epoch + 1) > opt.decay_epoch:
        # Linear decay reaching zero at the last epoch. It is computed from the
        # epoch index instead of by repeated subtraction, so rounding error
        # cannot accumulate and push the rate below zero.
        decay_frac = 1.0 - float(epoch + 1 - opt.decay_epoch) / (opt.train_epoch - opt.decay_epoch)
        for optimizer, base_lr in ((D_A_optimizer, opt.lrD), (D_B_optimizer, opt.lrD), (G_optimizer, opt.lrG)):
            for group in optimizer.param_groups:
                group['lr'] = max(0.0, base_lr * decay_frac)

    for (realA, _), (realB, _) in _lazy_zip(train_loader_A, train_loader_B):
        # augmentation: optional resize, random crop and random horizontal flip
        if opt.resize_scale:
            realA = util.imgs_resize(realA, opt.resize_scale)
            realB = util.imgs_resize(realB, opt.resize_scale)

        if opt.crop:
            realA = util.random_crop(realA, opt.input_size)
            realB = util.random_crop(realB, opt.input_size)

        if opt.fliplr:
            realA = util.random_fliplr(realA)
            realB = util.random_fliplr(realB)

        realA, realB = Variable(realA.cuda()), Variable(realB.cuda())

        # train generators G_A and G_B jointly
        G_optimizer.zero_grad()

        # generate real A to fake B; D_A(G_A(A))
        fakeB = G_A(realA)
        D_A_result = D_A(fakeB)
        G_A_loss = MSE_loss(D_A_result, Variable(torch.ones(D_A_result.size()).cuda()))

        # reconstruct fake B to rec A; G_B(G_A(A))
        recA = G_B(fakeB)
        A_cycle_loss = L1_loss(recA, realA) * opt.lambdaA

        # generate real B to fake A; D_B(G_B(B))
        fakeA = G_B(realB)
        D_B_result = D_B(fakeA)
        G_B_loss = MSE_loss(D_B_result, Variable(torch.ones(D_B_result.size()).cuda()))

        # reconstruct fake A to rec B; G_A(G_B(B))
        recB = G_A(fakeA)
        B_cycle_loss = L1_loss(recB, realB) * opt.lambdaB

        G_loss = G_A_loss + G_B_loss + A_cycle_loss + B_cycle_loss
        G_loss.backward()
        G_optimizer.step()

        _record('G_A_losses', G_A_losses, G_A_loss)
        _record('G_B_losses', G_B_losses, G_B_loss)
        _record('A_cycle_losses', A_cycle_losses, A_cycle_loss)
        _record('B_cycle_losses', B_cycle_losses, B_cycle_loss)

        # train discriminator D_A (real B vs. pooled fake B)
        D_A_optimizer.zero_grad()

        D_A_real = D_A(realB)
        D_A_real_loss = MSE_loss(D_A_real, Variable(torch.ones(D_A_real.size()).cuda()))

        fakeB = fakeB_store.query(fakeB)
        D_A_fake = D_A(fakeB)
        D_A_fake_loss = MSE_loss(D_A_fake, Variable(torch.zeros(D_A_fake.size()).cuda()))

        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_loss.backward()
        D_A_optimizer.step()

        _record('D_A_losses', D_A_losses, D_A_loss)

        # train discriminator D_B (real A vs. pooled fake A)
        D_B_optimizer.zero_grad()

        D_B_real = D_B(realA)
        D_B_real_loss = MSE_loss(D_B_real, Variable(torch.ones(D_B_real.size()).cuda()))

        fakeA = fakeA_store.query(fakeA)
        D_B_fake = D_B(fakeA)
        D_B_fake_loss = MSE_loss(D_B_fake, Variable(torch.zeros(D_B_fake.size()).cuda()))

        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_loss.backward()
        D_B_optimizer.step()

        _record('D_B_losses', D_B_losses, D_B_loss)

    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
    print(
        '[%d/%d] - ptime: %.2f, loss_D_A: %.3f, loss_D_B: %.3f, loss_G_A: %.3f, loss_G_B: %.3f, loss_A_cycle: %.3f, loss_B_cycle: %.3f' % (
            (epoch + 1), opt.train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_A_losses)),
            torch.mean(torch.FloatTensor(D_B_losses)), torch.mean(torch.FloatTensor(G_A_losses)),
            torch.mean(torch.FloatTensor(G_B_losses)), torch.mean(torch.FloatTensor(A_cycle_losses)),
            torch.mean(torch.FloatTensor(B_cycle_losses))))

    # Samples are written under ``root`` so they follow --save_root.
    if (epoch + 1) % 10 == 0:
        # every 10th epoch: translate the complete test sets
        _save_samples(test_loader_A, G_A, G_B, root + 'test_results/AtoB/')
        _save_samples(test_loader_B, G_B, G_A, root + 'test_results/BtoA/')
    else:
        # other epochs: a quick look at 10 training batches per direction
        _save_samples(train_loader_A, G_A, G_B, root + 'train_results/AtoB/', max_samples=10)
        _save_samples(train_loader_B, G_B, G_A, root + 'train_results/BtoA/', max_samples=10)
286 |
# Record wall-clock time for the whole run, then persist weights and history.
total_ptime = time.time() - start_time
train_hist['total_ptime'].append(total_ptime)

print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), opt.train_epoch, total_ptime))
print("Training finish!... save training results")
for trained_net, file_suffix in ((G_A, 'generatorA_param.pkl'),
                                 (G_B, 'generatorB_param.pkl'),
                                 (D_A, 'discriminatorA_param.pkl'),
                                 (D_B, 'discriminatorB_param.pkl')):
    torch.save(trained_net.state_dict(), root + model + file_suffix)
with open(root + model + 'train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

util.show_train_hist(train_hist, save=True, path=root + model + 'train_hist.png')
301 |
--------------------------------------------------------------------------------