├── Advanced-BicycleGAN
│   ├── dataloader.py
│   ├── model.py
│   ├── solver.py
│   ├── train.py
│   └── util.py
├── README.md
├── dataloader.py
├── model.py
├── png
│   ├── interpolation.png
│   ├── kl_1.png
│   ├── kl_2.png
│   ├── model.png
│   ├── random_sample.png
│   └── represent.png
├── solver.py
├── test.py
├── train.py
└── util.py
/Advanced-BicycleGAN/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import torchvision.transforms as Transforms
4 |
5 | import os
6 | from PIL import Image
7 |
8 | class Edges2Shoes(Dataset):
9 | def __init__(self, root, transform, mode='train'):
10 | self.root = root
11 | self.transform = transform
12 | self.mode = mode
13 |
14 | data_dir = os.path.join(root, mode)
15 | self.file_list = os.listdir(data_dir)
16 |
17 | def __len__(self):
18 | return len(self.file_list)
19 |
20 | def __getitem__(self, idx):
21 | img_path = os.path.join(self.root, self.mode, self.file_list[idx])
22 | img = Image.open(img_path)
23 | W, H = img.size[0], img.size[1]
24 |
25 | data = img.crop((0, 0, int(W / 2), H))
26 | ground_truth = img.crop((int(W / 2), 0, W, H))
27 |
28 | data = self.transform(data)
29 | ground_truth = self.transform(ground_truth)
30 |
31 | return (data, ground_truth)
32 |
33 | def data_loader(root, batch_size=1, shuffle=True, img_size=128, mode='train'):
34 | transform = Transforms.Compose([Transforms.Scale((img_size, img_size)),
35 | Transforms.ToTensor(),
36 | Transforms.Normalize(mean=(0.5, 0.5, 0.5),
37 | std=(0.5, 0.5, 0.5))
38 | ])
39 |
40 | dset = Edges2Shoes(root, transform, mode=mode)
41 |
42 | if batch_size == 'all':
43 | batch_size = len(dset)
44 |
45 | dloader = torch.utils.data.DataLoader(dset,
46 | batch_size=batch_size,
47 | shuffle=shuffle,
48 | num_workers=0,
49 | drop_last=True)
50 | dlen = len(dset)
51 |
52 | return dloader, dlen
--------------------------------------------------------------------------------
/Advanced-BicycleGAN/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | '''
5 | < ConvBlock >
6 | A small unit block consisting of [convolution layer - normalization layer - non-linearity layer]
7 |
8 | * Parameters
9 | 1. in_dim : Input dimension(channels number)
10 | 2. out_dim : Output dimension(channels number)
11 | 3. k : Kernel size(filter size)
12 | 4. s : stride
13 | 5. p : padding size
14 | 6. norm : If it is true add Instance Normalization layer, otherwise skip this layer
15 | 7. non_linear : You can choose between 'leaky_relu', 'relu', None
16 | '''
17 | class ConvBlock(nn.Module):
18 | def __init__(self, in_dim, out_dim, k=4, s=2, p=1, norm=True, non_linear='leaky_relu'):
19 | super(ConvBlock, self).__init__()
20 | layers = []
21 |
22 | # Convolution Layer
23 | layers += [nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p)]
24 |
25 | # Normalization Layer
26 | if norm is True:
27 | layers += [nn.InstanceNorm2d(out_dim, affine=True)]
28 |
29 | # Non-linearity Layer
30 | if non_linear == 'leaky_relu':
31 | layers += [nn.LeakyReLU(negative_slope=0.2, inplace=True)]
32 | elif non_linear == 'relu':
33 | layers += [nn.ReLU(inplace=True)]
34 |
35 | self.conv_block = nn.Sequential(* layers)
36 |
37 | def forward(self, x):
38 | out = self.conv_block(x)
39 | return out
40 |
41 | '''
42 | < DeconvBlock >
43 | A small unit block consisting of [transposed conv layer - normalization layer - non-linearity layer]
44 |
45 | * Parameters
46 | 1. in_dim : Input dimension(channels number)
47 | 2. out_dim : Output dimension(channels number)
48 | 3. k : Kernel size(filter size)
49 | 4. s : stride
50 | 5. p : padding size
51 | 6. norm : If it is true add Instance Normalization layer, otherwise skip this layer
52 | 7. non_linear : You can choose between 'relu', 'tanh', None
53 | '''
54 | class DeconvBlock(nn.Module):
55 | def __init__(self, in_dim, out_dim, k=4, s=2, p=1, norm=True, non_linear='relu'):
56 | super(DeconvBlock, self).__init__()
57 | layers = []
58 |
59 | # Transpose Convolution Layer
60 | layers += [nn.ConvTranspose2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p)]
61 |
62 | # Normalization Layer
63 | if norm is True:
64 | layers += [nn.InstanceNorm2d(out_dim, affine=True)]
65 |
66 | # Non-Linearity Layer
67 | if non_linear == 'relu':
68 | layers += [nn.ReLU(inplace=True)]
69 | elif non_linear == 'tanh':
70 | layers += [nn.Tanh()]
71 |
72 | self.deconv_block = nn.Sequential(* layers)
73 |
74 | def forward(self, x):
75 | out = self.deconv_block(x)
76 | return out
77 |
78 | '''
79 | < Generator >
80 | U-Net Generator. See https://arxiv.org/abs/1505.04597 figure 1
81 | or https://arxiv.org/pdf/1611.07004 6.1.1 Generator Architectures
82 |
83 | A downsampled activation volume and the upsampled activation volume with the same width and height
84 | make a pair, and they are concatenated when upsampling.
85 | Pairs : (up_1, down_6) (up_2, down_5) (up_3, down_4) (up_4, down_3) (up_5, down_2) (up_6, down_1)
86 | down_7 doesn't have a partner.
87 |
88 | ex) up_1 and down_6 both have size (N, 512, 2, 2) given that the input size is (N, 3, 128, 128).
89 | When forwarding into upsample_2, up_1 and down_6 are concatenated to make (N, 1024, 2, 2), and then
90 | upsample_2 produces (N, 512, 4, 4). That is why upsample_2 has 1024 input channels and 512 output channels.
91 |
92 | Except upsample_1, all the other upsampling blocks do the same thing.
93 | '''
94 | class Generator(nn.Module):
95 | def __init__(self, z_dim=8):
96 | super(Generator, self).__init__()
97 | # Reduce H and W by half at every downsampling
98 | self.downsample_1 = ConvBlock(3 + z_dim, 64, k=4, s=2, p=1, norm=False, non_linear='leaky_relu')
99 | self.downsample_2 = ConvBlock(64, 128, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
100 | self.downsample_3 = ConvBlock(128, 256, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
101 | self.downsample_4 = ConvBlock(256, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
102 | self.downsample_5 = ConvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
103 | self.downsample_6 = ConvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
104 | self.downsample_7 = ConvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
105 |
106 |         # Need concatenation when upsampling, see the forward function for details
107 | self.upsample_1 = DeconvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='relu')
108 | self.upsample_2 = DeconvBlock(1024, 512, k=4, s=2, p=1, norm=True, non_linear='relu')
109 | self.upsample_3 = DeconvBlock(1024, 512, k=4, s=2, p=1, norm=True, non_linear='relu')
110 | self.upsample_4 = DeconvBlock(1024, 256, k=4, s=2, p=1, norm=True, non_linear='relu')
111 | self.upsample_5 = DeconvBlock(512, 128, k=4, s=2, p=1, norm=True, non_linear='relu')
112 | self.upsample_6 = DeconvBlock(256, 64, k=4, s=2, p=1, norm=True, non_linear='relu')
113 |         self.upsample_7 = DeconvBlock(128, 3, k=4, s=2, p=1, norm=False, non_linear='tanh')
114 |
115 | def forward(self, x, z):
116 | # z : (N, z_dim) -> (N, z_dim, 1, 1) -> (N, z_dim, H, W)
117 | # x_with_z : (N, 3 + z_dim, H, W)
118 | z = z.unsqueeze(dim=2).unsqueeze(dim=3)
119 | z = z.expand(z.size(0), z.size(1), x.size(2), x.size(3))
120 | x_with_z = torch.cat([x, z], dim=1)
121 |
122 | down_1 = self.downsample_1(x_with_z)
123 | down_2 = self.downsample_2(down_1)
124 | down_3 = self.downsample_3(down_2)
125 | down_4 = self.downsample_4(down_3)
126 | down_5 = self.downsample_5(down_4)
127 | down_6 = self.downsample_6(down_5)
128 | down_7 = self.downsample_7(down_6)
129 |
130 | up_1 = self.upsample_1(down_7)
131 | up_2 = self.upsample_2(torch.cat([up_1, down_6], dim=1))
132 | up_3 = self.upsample_3(torch.cat([up_2, down_5], dim=1))
133 | up_4 = self.upsample_4(torch.cat([up_3, down_4], dim=1))
134 | up_5 = self.upsample_5(torch.cat([up_4, down_3], dim=1))
135 | up_6 = self.upsample_6(torch.cat([up_5, down_2], dim=1))
136 | out = self.upsample_7(torch.cat([up_6, down_1], dim=1))
137 |
138 | return out
139 |
140 | '''
141 | < Discriminator >
142 |
143 | PatchGAN discriminator. See https://arxiv.org/pdf/1611.07004 6.1.2 Discriminator architectures.
144 | It uses two discriminators which have different output sizes (different local probabilities).
145 |
146 | Furthermore, it is a conditional discriminator, so the input dimension is 6. You can make the input by
147 | concatenating two images into a pair of a Domain A image and a Domain B image.
148 | There are two cases to concatenate: [Domain_A, Domain_B_ground_truth] and [Domain_A, Domain_B_generated]
149 |
150 | d_1 : (N, 6, 128, 128) -> (N, 1, 14, 14)
151 | d_2 : (N, 6, 128, 128) -> (N, 1, 30, 30)
152 |
153 | In training, the generator needs to fool both d_1 and d_2, which makes the generator more robust.
154 |
155 | '''
156 | class Discriminator(nn.Module):
157 | def __init__(self):
158 | super(Discriminator, self).__init__()
159 | # Discriminator with last patch (14x14)
160 | # (N, 6, 128, 128) -> (N, 1, 14, 14)
161 | self.d_1 = nn.Sequential(nn.AvgPool2d(kernel_size=3, stride=2, padding=0, count_include_pad=False),
162 | ConvBlock(6, 32, k=4, s=2, p=1, norm=False, non_linear='leaky_relu'),
163 |                                  ConvBlock(32, 64, k=4, s=2, p=1, norm=True, non_linear='leaky_relu'),
164 |                                  ConvBlock(64, 128, k=4, s=1, p=1, norm=True, non_linear='leaky_relu'),
165 | ConvBlock(128, 1, k=4, s=1, p=1, norm=False, non_linear=None))
166 |
167 | # Discriminator with last patch (30x30)
168 | # (N, 6, 128, 128) -> (N, 1, 30, 30)
169 | self.d_2 = nn.Sequential(ConvBlock(6, 64, k=4, s=2, p=1, norm=False, non_linear='leaky_relu'),
170 |                                  ConvBlock(64, 128, k=4, s=2, p=1, norm=True, non_linear='leaky_relu'),
171 |                                  ConvBlock(128, 256, k=4, s=1, p=1, norm=True, non_linear='leaky_relu'),
172 | ConvBlock(256, 1, k=4, s=1, p=1, norm=False, non_linear=None))
173 |
174 | def forward(self, x):
175 | out_1 = self.d_1(x)
176 | out_2 = self.d_2(x)
177 | return (out_1, out_2)
178 |
179 | '''
180 | < ResBlock >
181 |
182 | This residual block is different from the one we usually know, which consists of
183 | [conv - norm - act - conv - norm] and an identity mapping (x -> x) for the shortcut.
184 |
185 | Also spatial size is decreased by half because of AvgPool2d.
186 | '''
187 | class ResBlock(nn.Module):
188 | def __init__(self, in_dim, out_dim):
189 | super(ResBlock, self).__init__()
190 | self.conv = nn.Sequential(nn.InstanceNorm2d(in_dim, affine=True),
191 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
192 | nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
193 | nn.InstanceNorm2d(in_dim, affine=True),
194 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
195 | nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
196 | nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
197 |
198 | self.short_cut = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
199 | nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0))
200 |
201 | def forward(self, x):
202 | out = self.conv(x) + self.short_cut(x)
203 | return out
204 |
205 | '''
206 | < Encoder >
207 |
208 | Output is mu and log(var) for the reparameterization trick used in Variational Auto-Encoders.
209 | Encoding is done in this order.
210 | 1. Use this encoder and get mu and log_var
211 | 2. std = exp(log_var / 2)
212 | 3. random_z ~ N(0, 1)
213 | 4. encoded_z = random_z * std + mu (Reparameterization trick)
214 | '''
215 | class Encoder(nn.Module):
216 | def __init__(self, z_dim=8):
217 | super(Encoder, self).__init__()
218 |
219 | self.conv = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
220 | self.res_blocks = nn.Sequential(ResBlock(64, 128),
221 | ResBlock(128, 192),
222 | ResBlock(192, 256))
223 | self.pool_block = nn.Sequential(nn.LeakyReLU(negative_slope=0.2, inplace=True),
224 | nn.AvgPool2d(kernel_size=8, stride=8, padding=0))
225 |
226 | # Return mu and logvar for reparameterization trick
227 | self.fc_mu = nn.Linear(256, z_dim)
228 | self.fc_logvar = nn.Linear(256, z_dim)
229 |
230 | def forward(self, x):
231 | # (N, 3, 128, 128) -> (N, 64, 64, 64)
232 | out = self.conv(x)
233 | # (N, 64, 64, 64) -> (N, 128, 32, 32) -> (N, 192, 16, 16) -> (N, 256, 8, 8)
234 | out = self.res_blocks(out)
235 | # (N, 256, 8, 8) -> (N, 256, 1, 1)
236 | out = self.pool_block(out)
237 | # (N, 256, 1, 1) -> (N, 256)
238 | out = out.view(x.size(0), -1)
239 |
240 | # (N, 256) -> (N, z_dim) x 2
241 | mu = self.fc_mu(out)
242 | log_var = self.fc_logvar(out)
243 |
244 | return (mu, log_var)
--------------------------------------------------------------------------------
/Advanced-BicycleGAN/solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.optim as optim
5 | import torchvision
6 |
7 | from dataloader import data_loader
8 | import model
9 | import util
10 |
11 | import os
12 |
13 | '''
14 | < mse_loss >
15 | Calculate mean squared error loss
16 |
17 | * Parameters
18 | score : Output of discriminator
19 | target : 1 for real and 0 for fake
20 | '''
21 | def mse_loss(score, target=1):
22 |
23 |
24 | if target == 1:
25 | label = util.var(torch.ones(score.size()), requires_grad=False)
26 | elif target == 0:
27 | label = util.var(torch.zeros(score.size()), requires_grad=False)
28 |
29 | criterion = nn.MSELoss()
30 | loss = criterion(score, label)
31 |
32 | return loss
33 |
34 | '''
35 | < L1_loss >
36 | Calculate L1 loss
37 |
38 | * Parameters
39 | pred : Output of network
40 | target : Ground truth
41 | '''
42 | def L1_loss(pred, target):
43 | return torch.mean(torch.abs(pred - target))
44 |
45 | def lr_decay_rule(epoch, start_decay=100, lr_decay=100):
46 | decay_rate = 1.0 - (max(0, epoch - start_decay) / float(lr_decay))
47 | return decay_rate
48 |
49 | class Solver():
50 | def __init__(self, root='data/edges2shoes', result_dir='result', weight_dir='weight', load_weight=False,
51 | batch_size=2, test_size=20, test_img_num=5, img_size=128, num_epoch=100, save_every=1000,
52 | lr=0.0002, beta_1=0.5, beta_2=0.999, lambda_kl=0.01, lambda_img=10, lambda_z=0.5, z_dim=8):
53 |
54 |         # Data type (use GPU if available)
55 | self.dtype = torch.cuda.FloatTensor
56 | if torch.cuda.is_available() is False:
57 | self.dtype = torch.FloatTensor
58 |
59 | # Data loader for training
60 | self.dloader, dlen = data_loader(root=root, batch_size=batch_size, shuffle=True,
61 | img_size=img_size, mode='train')
62 |
63 | # Data loader for test
64 | self.t_dloader, _ = data_loader(root=root, batch_size=test_size, shuffle=False,
65 | img_size=img_size, mode='val')
66 |
67 | # Models
68 |         # D_cVAE is the discriminator for cVAE-GAN (encoded vector z).
69 |         # D_cLR is the discriminator for cLR-GAN (random vector z).
70 |         # Both D_cVAE and D_cLR consist of two discriminators with different output sizes ((14x14) and (30x30)).
71 |         # In total, we have four discriminators now.
72 | self.D_cVAE = model.Discriminator().type(self.dtype)
73 | self.D_cLR = model.Discriminator().type(self.dtype)
74 | self.G = model.Generator(z_dim=z_dim).type(self.dtype)
75 | self.E = model.Encoder(z_dim=z_dim).type(self.dtype)
76 |
77 | # Optimizers
78 | self.optim_D_cVAE = optim.Adam(self.D_cVAE.parameters(), lr=lr, betas=(beta_1, beta_2))
79 | self.optim_D_cLR = optim.Adam(self.D_cLR.parameters(), lr=lr, betas=(beta_1, beta_2))
80 | self.optim_G = optim.Adam(self.G.parameters(), lr=lr, betas=(beta_1, beta_2))
81 | self.optim_E = optim.Adam(self.E.parameters(), lr=lr, betas=(beta_1, beta_2))
82 |
83 |         # Optimizer lr scheduler
84 | #self.optim_D_scheduler = optim.lr_scheduler.LambdaLR(self.optim_D, lr_lambda=lr_decay_rule)
85 | #self.optim_G_scheduler = optim.lr_scheduler.LambdaLR(self.optim_G, lr_lambda=lr_decay_rule)
86 | #self.optim_E_scheduler = optim.lr_scheduler.LambdaLR(self.optim_E, lr_lambda=lr_decay_rule)
87 |
88 | # fixed random_z for test
89 | self.fixed_z = util.var(torch.randn(test_size, test_img_num, z_dim))
90 |
91 | # Some hyperparameters
92 | self.z_dim = z_dim
93 | self.lambda_kl = lambda_kl
94 | self.lambda_img = lambda_img
95 | self.lambda_z = lambda_z
96 |
97 | # Extra things
98 | self.result_dir = result_dir
99 | self.weight_dir = weight_dir
100 | self.load_weight = load_weight
101 | self.test_img_num = test_img_num
102 | self.img_size = img_size
103 | self.start_epoch = 0
104 | self.num_epoch = num_epoch
105 | self.save_every = save_every
106 |
107 | '''
108 | < show_model >
109 | Print model architectures
110 | '''
111 | def show_model(self):
112 | print('=========================== Discriminator for cVAE ===========================')
113 | print(self.D_cVAE)
114 | print('=============================================================================\n\n')
115 | print('=========================== Discriminator for cLR ===========================')
116 | print(self.D_cLR)
117 | print('=============================================================================\n\n')
118 | print('================================= Generator =================================')
119 | print(self.G)
120 | print('=============================================================================\n\n')
121 | print('================================== Encoder ==================================')
122 | print(self.E)
123 | print('=============================================================================\n\n')
124 |
125 | '''
126 | < set_train_phase >
127 | Set training phase
128 | '''
129 | def set_train_phase(self):
130 | self.D_cVAE.train()
131 | self.D_cLR.train()
132 | self.G.train()
133 | self.E.train()
134 |
135 | '''
136 | < load_pretrained >
137 | If you want to continue to train, load pretrained weight
138 | '''
139 | def load_pretrained(self):
140 | self.D_cVAE.load_state_dict(torch.load(os.path.join(self.weight_dir, 'D_cVAE.pkl')))
141 | self.D_cLR.load_state_dict(torch.load(os.path.join(self.weight_dir, 'D_cLR.pkl')))
142 | self.G.load_state_dict(torch.load(os.path.join(self.weight_dir, 'G.pkl')))
143 | self.E.load_state_dict(torch.load(os.path.join(self.weight_dir, 'E.pkl')))
144 |
145 |         with open('log.txt', 'r') as log_file:
146 |             line = log_file.readline()
147 |         self.start_epoch = int(line)
148 |
149 | '''
150 | < save_weight >
151 | Save weight
152 | '''
153 | def save_weight(self, epoch=None):
154 | if epoch is None:
155 | d_cVAE_name = 'D_cVAE.pkl'
156 | d_cLR_name = 'D_cLR.pkl'
157 | g_name = 'G.pkl'
158 | e_name = 'E.pkl'
159 | else:
160 | d_cVAE_name = '{epochs}-{name}'.format(epochs=str(epoch), name='D_cVAE.pkl')
161 | d_cLR_name = '{epochs}-{name}'.format(epochs=str(epoch), name='D_cLR.pkl')
162 | g_name = '{epochs}-{name}'.format(epochs=str(epoch), name='G.pkl')
163 | e_name = '{epochs}-{name}'.format(epochs=str(epoch), name='E.pkl')
164 |
165 | torch.save(self.D_cVAE.state_dict(), os.path.join(self.weight_dir, d_cVAE_name))
166 |         torch.save(self.D_cLR.state_dict(), os.path.join(self.weight_dir, d_cLR_name))
167 | torch.save(self.G.state_dict(), os.path.join(self.weight_dir, g_name))
168 | torch.save(self.E.state_dict(), os.path.join(self.weight_dir, e_name))
169 |
170 | '''
171 | < all_zero_grad >
172 | Set all optimizers' grad to zero
173 | '''
174 | def all_zero_grad(self):
175 | self.optim_D_cVAE.zero_grad()
176 | self.optim_D_cLR.zero_grad()
177 | self.optim_G.zero_grad()
178 | self.optim_E.zero_grad()
179 |
180 | '''
181 | < train >
182 | Train the D_cVAE, D_cLR, G and E
183 | '''
184 | def train(self):
185 | if self.load_weight is True:
186 | self.load_pretrained()
187 |
188 | self.set_train_phase()
189 | self.show_model()
190 |
191 | # Training Start!
192 | for epoch in range(self.start_epoch, self.num_epoch):
193 | for iters, (img, ground_truth) in enumerate(self.dloader):
194 |                 # img (2, 3, 128, 128) : Two images in Domain A. One for cVAE and another for cLR.
195 |                 # ground_truth (2, 3, 128, 128) : Two images in Domain B. One for cVAE and another for cLR.
196 | img, ground_truth = util.var(img), util.var(ground_truth)
197 |
198 |                 # Separate data for cVAE_GAN (using encoded z) and cLR_GAN (using random z)
199 | cVAE_data = {'img' : img[0].unsqueeze(dim=0), 'ground_truth' : ground_truth[0].unsqueeze(dim=0)}
200 | cLR_data = {'img' : img[1].unsqueeze(dim=0), 'ground_truth' : ground_truth[1].unsqueeze(dim=0)}
201 |
202 | ''' ----------------------------- 1. Train D ----------------------------- '''
203 | ####################### < Step 1. D loss in cVAE-GAN > #######################
204 |
205 | # Encoded latent vector
206 | mu, log_variance = self.E(cVAE_data['ground_truth'])
207 | std = torch.exp(log_variance / 2)
208 | random_z = util.var(torch.randn(1, self.z_dim))
209 | encoded_z = (random_z * std) + mu
210 |
211 | # Generate fake image
212 | fake_img_cVAE = self.G(cVAE_data['img'], encoded_z)
213 |
214 | real_pair_cVAE = torch.cat([cVAE_data['img'], cVAE_data['ground_truth']], dim=1)
215 | fake_pair_cVAE = torch.cat([cVAE_data['img'], fake_img_cVAE], dim=1)
216 |
217 | real_d_cVAE_1, real_d_cVAE_2 = self.D_cVAE(real_pair_cVAE)
218 | fake_d_cVAE_1, fake_d_cVAE_2 = self.D_cVAE(fake_pair_cVAE.detach())
219 |
220 | D_loss_cVAE_1 = mse_loss(real_d_cVAE_1, 1) + mse_loss(fake_d_cVAE_1, 0) # Small patch loss
221 | D_loss_cVAE_2 = mse_loss(real_d_cVAE_2, 1) + mse_loss(fake_d_cVAE_2, 0) # Big patch loss
222 |
223 | ####################### < Step 2. D loss in cLR-GAN > #######################
224 |
225 | # Generate fake image
226 | # Generated img using 'cVAE' data will be used to train D_'cLR'
227 | fake_img_cLR = self.G(cVAE_data['img'], random_z)
228 |
229 | real_pair_cLR = torch.cat([cLR_data['img'], cLR_data['ground_truth']], dim=1)
230 | fake_pair_cLR = torch.cat([cVAE_data['img'], fake_img_cLR], dim=1)
231 |
232 |                 # A_cVAE = Domain A image for cVAE, A_cLR = Domain A image for cLR
233 |                 # B_cVAE = Domain B image for cVAE, B_cLR = Domain B image for cLR
234 |
235 | # D_cVAE has to discriminate [A_cVAE, B_cVAE] vs [A_cVAE, G(A_cVAE, encoded_z)]
236 | # D_cLR has to discriminate [A_cLR, B_cLR] vs [A_cVAE, G(A_cVAE, random_z)]
237 |
238 | # This helps to generate more diverse images
239 | real_d_cLR_1, real_d_cLR_2 = self.D_cLR(real_pair_cLR)
240 | fake_d_cLR_1, fake_d_cLR_2 = self.D_cLR(fake_pair_cLR.detach())
241 |
242 | D_loss_cLR_1 = mse_loss(real_d_cLR_1, 1) + mse_loss(fake_d_cLR_1, 0) # Small patch loss
243 | D_loss_cLR_2 = mse_loss(real_d_cLR_2, 1) + mse_loss(fake_d_cLR_2, 0) # Big patch loss
244 |
245 | D_loss = D_loss_cVAE_1 + D_loss_cVAE_2 + D_loss_cLR_1 + D_loss_cLR_2
246 |
247 | # Update D
248 | self.all_zero_grad()
249 | D_loss.backward()
250 | self.optim_D_cVAE.step()
251 | self.optim_D_cLR.step()
252 |
253 | ''' ----------------------------- 2. Train G & E ----------------------------- '''
254 | ########### < Step 1. GAN loss to fool discriminator (cVAE_GAN and cLR_GAN) > ###########
255 |
256 | # Encoded latent vector
257 | mu, log_variance = self.E(cVAE_data['ground_truth'])
258 | std = torch.exp(log_variance / 2)
259 | random_z = util.var(torch.randn(1, self.z_dim))
260 | encoded_z = (random_z * std) + mu
261 |
262 | # Generate fake image
263 | fake_img_cVAE = self.G(cVAE_data['img'], encoded_z)
264 | fake_pair_cVAE = torch.cat([cVAE_data['img'], fake_img_cVAE], dim=1)
265 |
266 | # Fool D_cVAE
267 | fake_d_cVAE_1, fake_d_cVAE_2 = self.D_cVAE(fake_pair_cVAE)
268 |
269 | GAN_loss_cVAE_1 = mse_loss(fake_d_cVAE_1, 1) # Small patch loss
270 | GAN_loss_cVAE_2 = mse_loss(fake_d_cVAE_2, 1) # Big patch loss
271 |
272 | # Random latent vector and generate fake image
273 | random_z = util.var(torch.randn(1, self.z_dim))
274 | fake_img_cLR = self.G(cLR_data['img'], random_z)
275 | fake_pair_cLR = torch.cat([cLR_data['img'], fake_img_cLR], dim=1)
276 |
277 | # Fool D_cLR
278 | fake_d_cLR_1, fake_d_cLR_2 = self.D_cLR(fake_pair_cLR)
279 |
280 | GAN_loss_cLR_1 = mse_loss(fake_d_cLR_1, 1) # Small patch loss
281 | GAN_loss_cLR_2 = mse_loss(fake_d_cLR_2, 1) # Big patch loss
282 |
283 | G_GAN_loss = GAN_loss_cVAE_1 + GAN_loss_cVAE_2 + GAN_loss_cLR_1 + GAN_loss_cLR_2
284 |
285 | ################# < Step 2. KL-divergence with N(0, 1) (cVAE-GAN) > #################
286 |
287 | # See http://yunjey47.tistory.com/43 or Appendix B in the paper for details
288 | KL_div = self.lambda_kl * torch.sum(0.5 * (mu ** 2 + torch.exp(log_variance) - log_variance - 1))
289 |
290 | #### < Step 3. Reconstruction of ground truth image (|G(A, z) - B|) (cVAE-GAN) > ####
291 | img_recon_loss = self.lambda_img * L1_loss(fake_img_cVAE, cVAE_data['ground_truth'])
292 |
293 | EG_loss = G_GAN_loss + KL_div + img_recon_loss
294 | self.all_zero_grad()
295 | EG_loss.backward(retain_graph=True) # retain_graph=True for the next step 3. Train ONLY G
296 | self.optim_E.step()
297 | self.optim_G.step()
298 |
299 | ''' ----------------------------- 3. Train ONLY G ----------------------------- '''
300 | ##### < Step 1. Reconstrution of random latent code (|E(G(A, z)) - z|) (cLR-GAN) > #####
301 |
302 | # This step should update only G.
303 | # See https://github.com/junyanz/BicycleGAN/issues/5 for details.
304 | mu, log_variance = self.E(fake_img_cLR)
305 | z_recon_loss = L1_loss(mu, random_z)
306 |
307 | z_recon_loss = self.lambda_z * z_recon_loss
308 |
309 | self.all_zero_grad()
310 | z_recon_loss.backward()
311 | self.optim_G.step()
312 |
313 |                 with open('log.txt', 'w') as log_file:
314 |                     log_file.write(str(epoch))
315 |
316 | # Print error, save intermediate result image and weight
317 | if iters % self.save_every == 0:
318 | print('[Epoch : %d / Iters : %d] => D_loss : %f / G_GAN_loss : %f / KL_div : %f / img_recon_loss : %f / z_recon_loss : %f'\
319 | %(epoch, iters, D_loss.data[0], G_GAN_loss.data[0], KL_div.data[0], img_recon_loss.data[0], z_recon_loss.data[0]))
320 |
321 | # Save intermediate result image
322 | if os.path.exists(self.result_dir) is False:
323 | os.makedirs(self.result_dir)
324 |
325 | result_img = util.make_img(self.t_dloader, self.G, self.fixed_z,
326 | img_num=self.test_img_num, img_size=self.img_size)
327 |
328 | img_name = '{epoch}_{iters}.png'.format(epoch=epoch, iters=iters)
329 | img_path = os.path.join(self.result_dir, img_name)
330 |
331 | torchvision.utils.save_image(result_img, img_path, nrow=self.test_img_num+1)
332 |
333 | # Save intermediate weight
334 | if os.path.exists(self.weight_dir) is False:
335 | os.makedirs(self.weight_dir)
336 |
337 | self.save_weight()
338 |
339 | # Save weight at the end of every epoch
340 | self.save_weight(epoch=epoch)
341 |
--------------------------------------------------------------------------------
/Advanced-BicycleGAN/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 |
5 | from solver import Solver
6 |
7 | def main(args):
8 | solver = Solver(root = args.root,
9 | result_dir = args.result_dir,
10 | weight_dir = args.weight_dir,
11 | load_weight = args.load_weight,
12 | batch_size = args.batch_size,
13 | test_size = args.test_size,
14 | test_img_num = args.test_img_num,
15 | img_size = args.img_size,
16 | num_epoch = args.num_epoch,
17 | save_every = args.save_every,
18 | lr = args.lr,
19 | beta_1 = args.beta_1,
20 | beta_2 = args.beta_2,
21 | lambda_kl = args.lambda_kl,
22 | lambda_img = args.lambda_img,
23 | lambda_z = args.lambda_z,
24 | z_dim = args.z_dim)
25 |
26 | solver.train()
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--root', type=str, default='data/edges2shoes',
31 | help='Data location')
32 | parser.add_argument('--result_dir', type=str, default='test',
33 | help='Result images location')
34 | parser.add_argument('--weight_dir', type=str, default='weight',
35 | help='Weight location')
36 | parser.add_argument('--batch_size', type=int, default=2,
37 | help='Training batch size')
38 | parser.add_argument('--test_size', type=int, default=20,
39 | help='Test batch size')
40 | parser.add_argument('--test_img_num', type=int, default=5,
41 | help='How many images do you want to generate?')
42 | parser.add_argument('--img_size', type=int, default=128,
43 | help='Image size')
44 | parser.add_argument('--lr', type=float, default=0.0002,
45 | help='Learning rate')
46 | parser.add_argument('--beta_1', type=float, default=0.5,
47 | help='Beta1 for Adam')
48 | parser.add_argument('--beta_2', type=float, default=0.999,
49 | help='Beta2 for Adam')
50 | parser.add_argument('--lambda_kl', type=float, default=0.01,
51 | help='Lambda for KL Divergence')
52 | parser.add_argument('--lambda_img', type=float, default=10,
53 | help='Lambda for image reconstruction')
54 | parser.add_argument('--lambda_z', type=float, default=0.5,
55 | help='Lambda for z reconstruction')
56 | parser.add_argument('--z_dim', type=int, default=8,
57 | help='Dimension of z')
58 | parser.add_argument('--num_epoch', type=int, default=100,
59 | help='Number of epoch')
60 | parser.add_argument('--save_every', type=int, default=1000,
61 | help='How often do you want to see the result?')
62 | parser.add_argument('--load_weight', action='store_true',
63 | help='Load weight or not')
64 |
65 | args = parser.parse_args()
66 | main(args)
--------------------------------------------------------------------------------
/Advanced-BicycleGAN/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 | '''
5 | < var >
6 | Convert tensor to Variable
7 | '''
8 | def var(tensor, requires_grad=True):
9 | if torch.cuda.is_available():
10 | dtype = torch.cuda.FloatTensor
11 | else:
12 | dtype = torch.FloatTensor
13 |
14 | var = Variable(tensor.type(dtype), requires_grad=requires_grad)
15 |
16 | return var
17 |
18 | '''
19 | < make_img >
20 | Generate images
21 |
22 | * Parameters
23 | dloader : Data loader for test data set
24 | G : Generator
25 |     z : random_z (size = (N, img_num, z_dim))
26 |         N : number of test images / img_num : number of images generated per test image / z_dim : 8
27 |     img_num : Number of images to generate per test image
28 | '''
29 | def make_img(dloader, G, z, img_num=5, img_size=128):
30 | if torch.cuda.is_available():
31 | dtype = torch.cuda.FloatTensor
32 | else:
33 | dtype = torch.FloatTensor
34 |
35 | dloader = iter(dloader)
36 |     img, _ = next(dloader)
37 |
38 | N = img.size(0)
39 | img = var(img.type(dtype))
40 |
41 | result_img = torch.FloatTensor(N * (img_num + 1), 3, img_size, img_size).type(dtype)
42 |
43 | for i in range(N):
44 | # original image to the leftmost
45 | result_img[i * (img_num + 1)] = img[i].data
46 |
47 | # Insert generated images to the next of the original image
48 | for j in range(img_num):
49 | img_ = img[i].unsqueeze(dim=0)
50 | z_ = z[i, j, :].unsqueeze(dim=0)
51 |
52 | out_img = G(img_, z_)
53 | result_img[i * (img_num + 1) + j + 1] = out_img.data
54 |
55 |
56 | # [-1, 1] -> [0, 1]
57 | result_img = result_img / 2 + 0.5
58 |
59 | return result_img
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BicycleGAN-pytorch
2 | __PyTorch__ implementation of [BicycleGAN: Toward Multimodal Image-to-Image Translation](https://arxiv.org/abs/1711.11586).
3 | 
4 |
5 | ## Result
6 | ### Edges2Shoes
7 | Image size is 128 x 128 and a __normal discriminator__ is used, __not a conditional discriminator__. You can check what the conditional discriminator is in ```Advanced-BicycleGAN``` in this repository. It generates slightly more diverse, clear and realistic images than the ones below.
8 |
9 | * Random sampling
10 | 
11 |
12 | * Linear interpolated sampling
13 | 
14 |
15 | ## Model description
16 | 
17 |
18 | ### cVAE-GAN
19 | cVAE-GAN is an __image reconstruction process.__ From this, the encoder can extract a proper latent code z containing the features of a given image 'B'. Through this process, the generator learns to generate an image which has the features of 'B', while also being able to fool the discriminator. Furthermore, cVAE-GAN uses KL-divergence so that, at test time, the generator can generate images from z sampled randomly from a normal distribution.
20 |
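Below is a minimal sketch of the cVAE-GAN branch with toy CPU tensors (`real_A` and `real_B` stand in for a paired training example; run from the repository root so that `model.py` is importable). It assumes the PyTorch version from the Prerequisites; recent PyTorch releases may reject `InstanceNorm2d` on the 1x1 bottleneck inside the generator.

```python
import torch
from torch.autograd import Variable

import model

E, G = model.Encoder(z_dim=8), model.Generator(z_dim=8)
real_A = Variable(torch.randn(1, 3, 128, 128))  # toy stand-in for a domain A image
real_B = Variable(torch.randn(1, 3, 128, 128))  # toy stand-in for its domain B pair

mu, log_var = E(real_B)                              # encode the target image B
std = torch.exp(log_var / 2)
encoded_z = Variable(torch.randn(1, 8)) * std + mu   # reparameterization trick
fake_B = G(real_A, encoded_z)                        # G should reconstruct B from A and encoded_z
```

The adversarial, image reconstruction and KL losses are then computed from `fake_B`, `real_B` and `(mu, log_var)`; see `solver.py` for the full version.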
21 | ### cLR-GAN
22 | This is a __latent code reconstruction process.__ If many latent codes correspond to the same output mode, this is called mode collapse. The main purpose of cLR-GAN is to make an invertible mapping between B and z. It leads to bijective consistency between latent encodings and output modes, which is significant in preventing the model from __mode collapse.__
23 |
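Continuing the sketch above, the cLR-GAN branch goes the other way: sample a code first, then ask the encoder to recover it from the generated image.

```python
random_z = Variable(torch.randn(1, 8))  # sample a code directly from N(0, 1)
fake_B = G(real_A, random_z)            # generate with the random code
mu, _ = E(fake_B)                       # E tries to recover the code from the output
z_recon_loss = torch.mean(torch.abs(mu - random_z))  # L1(mu, random_z); updates only G
```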
24 | ## Prerequisites
25 | * [Python 3.5+](https://www.continuum.io/downloads)
26 | * [PyTorch 0.2.0](http://pytorch.org/)
27 |
28 | ## Training step
29 | Before getting started, suppose that we want to optimize G, which converts __domain A into B__.
30 |
31 | __real_B__ : A real image of domain B from training data set
32 | __fake_B__ : A fake image of domain B made by the generator
33 | __encoded_z__ : Latent code z made by the encoder
34 | __random_z__ : Latent code z sampled randomly from normal distribution
35 |
36 | __1. Optimize D__
37 | * Optimize D in cVAE-GAN using real_B and fake_B made with encoded_z(__Adversarial loss__).
38 | * Optimize D in cLR-GAN using real_B and fake_B made with random_z(__Adversarial loss__).
39 |
40 | __2. Optimize G or E__
41 | * Optimize G and E in cVAE-GAN using fake_B made with encoded_z(__Adversarial loss__).
42 | * Optimize G and E in cVAE-GAN using real_B and fake_B made with encoded_z(__Image reconstruction loss__).
43 | * Optimize E in cVAE-GAN using the encoder outputs, mu and log_variance(__KL-div loss__).
44 | * Optimize G in cLR-GAN using fake_B made with random_z(__Adversarial loss__).
45 |
46 | __3. Optimize ONLY G(Do not update E)__
47 | * Optimize G in cLR-GAN using random_z and the encoder output mu(__Latent code reconstruction loss__).
48 |
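The three phases above map onto the training loop in `solver.py` roughly as follows (a condensed sketch with the `self.` prefixes elided, not runnable on its own):

```python
# 1. Update D_cVAE and D_cLR only
all_zero_grad(); D_loss.backward()
optim_D_cVAE.step(); optim_D_cLR.step()

# 2. Update G and E together (adversarial + KL + image reconstruction)
all_zero_grad()
EG_loss = G_GAN_loss + KL_div + img_recon_loss
EG_loss.backward(retain_graph=True)   # the graph is reused by the G-only step
optim_E.step(); optim_G.step()

# 3. Update ONLY G (latent code reconstruction); E is intentionally not stepped
all_zero_grad(); z_recon_loss.backward()
optim_G.step()
```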
49 | ## Implementation details
50 |
51 | * __Multi discriminator__
52 | First, __two discriminators__ are used for __two different last output sizes (PatchGAN)__, 14x14 and 30x30, so that the discriminator learns from two different scales.
53 | Second, each of these exists in __two copies__ because images are made with both __encoded_z (cVAE-GAN) and random_z (cLR-GAN)__, drawn from N(mu, std) and N(0, 1) respectively. Two separate discriminators work better than a single discriminator covering both distributions.
54 | In total, __four discriminators__ are used: __(cVAE-GAN, 14x14), (cVAE-GAN, 30x30), (cLR-GAN, 14x14) and (cLR-GAN, 30x30).__
55 |
56 | * __Encoder__
57 | __E_ResNet__ is used, __not E_CNN__. The residual block in the encoder is slightly different from the usual one. Check the ResBlock and Encoder classes in model.py.
58 |
59 | * __How to inject the latent code z to the generator__
60 | Inject __only to the input__ by concatenating, not to all intermediate layers
61 |
62 | * __Training data__
63 | The effective batch size is 1 for each of cVAE-GAN and cLR-GAN: two images are taken from the dataloader and distributed, one to cVAE-GAN and one to cLR-GAN (see the snippet below).
64 |
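This is how the training loop in `solver.py` splits the batch, shown here with toy tensors standing in for a real batch:

```python
import torch

# A batch of two image pairs, as returned by the dataloader (toy tensors here)
img, ground_truth = torch.randn(2, 3, 128, 128), torch.randn(2, 3, 128, 128)

cVAE_data = {'img': img[0].unsqueeze(dim=0), 'ground_truth': ground_truth[0].unsqueeze(dim=0)}
cLR_data = {'img': img[1].unsqueeze(dim=0), 'ground_truth': ground_truth[1].unsqueeze(dim=0)}
```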
65 | * __How to encode with encoder__
66 | The encoder returns mu and log_variance. The reparameterization trick is used, so __encoded_z = random_z * std + mu__ where __std = exp(log_variance / 2)__, as sketched below.
67 |
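A minimal runnable sketch of that computation, with random tensors standing in for the encoder outputs:

```python
import torch

z_dim = 8
mu, log_variance = torch.randn(1, z_dim), torch.randn(1, z_dim)  # stand-ins for E(B)

std = torch.exp(log_variance / 2)
random_z = torch.randn(1, z_dim)
encoded_z = random_z * std + mu  # reparameterization trick
```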
68 | * __How to calculate KL divergence__
69 | The following formula is from [here](http://yunjey47.tistory.com/43). Also, if you want to see simple and clean VAE code, you can check [here](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/variational_auto_encoder/main.py).
70 | 
71 | Taking the KL divergence with respect to N(0, 1) leads to the following formula.
72 | 
73 |
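In code this is exactly the term computed in `solver.py`; here the encoder outputs are toy zero tensors, for which the divergence is 0:

```python
import torch

# KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * sum( mu^2 + sigma^2 - log(sigma^2) - 1 )
mu, log_variance = torch.zeros(1, 8), torch.zeros(1, 8)
KL_div = torch.sum(0.5 * (mu ** 2 + torch.exp(log_variance) - log_variance - 1))
print(KL_div)  # 0 when mu = 0 and variance = 1
```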
74 | * __How to reconstruct z in cLR-GAN__
75 | mu and log_variance are derived from the encoder in cLR-GAN. Use the __L1 loss between mu and random_z__, not between encoded_z and random_z. The reasons are as follows; you can also check [here](https://github.com/junyanz/BicycleGAN/issues/14).
76 |
77 | 1. cLR-GAN is for point estimation not distribution estimation.
78 | 2. If std is too big, L1 loss between encoded_z and random_z can be unstable.
79 |
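A toy sketch of that loss with stand-in tensors (`lambda_z = 0.5` is the default in `train.py`):

```python
import torch

random_z = torch.randn(1, 8)   # the code fed to G
mu = torch.randn(1, 8)         # what E recovers from G's output in solver.py
z_recon_loss = 0.5 * torch.mean(torch.abs(mu - random_z))  # lambda_z * L1(mu, random_z)
```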
80 | ## Dataset
81 | You can download many datasets for BicycleGAN from [here](https://github.com/junyanz/BicycleGAN/tree/master/datasets).
82 |
83 | * Training images : ```data/edges2shoes/train```
84 | * Test images : ```data/edges2shoes/test```
85 |
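Each file stores the pair side by side as [A | B]; `dataloader.py` splits it like this (the file name below is hypothetical):

```python
from PIL import Image

img = Image.open('data/edges2shoes/train/1_AB.jpg')
W, H = img.size
data = img.crop((0, 0, W // 2, H))           # domain A (edge map)
ground_truth = img.crop((W // 2, 0, W, H))   # domain B (shoe photo)
```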
86 | ## How to use
87 | ### Train
88 | ```python train.py --root=data/edges2shoes --result_dir=result --weight_dir=weight```
89 |
90 | ### Test
91 | #### Random sample
92 | * Most recent
93 | ```python test.py --sample_type=random --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=5```
94 |
95 | * Set epoch
96 | ```python test.py --sample_type=random --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=5 --epoch=55```
97 |
98 | #### Interpolation
99 | * Most recent
100 | ```python test.py --sample_type=interpolation --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=10```
101 |
102 | * Set epoch
103 | ```python test.py --sample_type=interpolation --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=10 --epoch=55```
104 |
105 | ## Future work
106 | * Training with other datasets.
107 | * ~~New model using conditional discriminator is on the training now~~ Check ```Advanced-BicycleGAN```
108 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import torchvision.transforms as Transforms
4 |
5 | import os
6 | from PIL import Image
7 |
8 | class Edges2Shoes(Dataset):
9 | def __init__(self, root, transform, mode='train'):
10 | self.root = root
11 | self.transform = transform
12 | self.mode = mode
13 |
14 | data_dir = os.path.join(root, mode)
15 | self.file_list = os.listdir(data_dir)
16 |
17 | def __len__(self):
18 | return len(self.file_list)
19 |
20 | def __getitem__(self, idx):
21 | img_path = os.path.join(self.root, self.mode, self.file_list[idx])
22 | img = Image.open(img_path)
23 | W, H = img.size[0], img.size[1]
24 |
25 | data = img.crop((0, 0, int(W / 2), H))
26 | ground_truth = img.crop((int(W / 2), 0, W, H))
27 |
28 | data = self.transform(data)
29 | ground_truth = self.transform(ground_truth)
30 |
31 | return (data, ground_truth)
32 |
33 | def data_loader(root, batch_size=1, shuffle=True, img_size=128, mode='train'):
34 | transform = Transforms.Compose([Transforms.Scale((img_size, img_size)),
35 | Transforms.ToTensor(),
36 | Transforms.Normalize(mean=(0.5, 0.5, 0.5),
37 | std=(0.5, 0.5, 0.5))
38 | ])
39 |
40 | dset = Edges2Shoes(root, transform, mode=mode)
41 |
42 | if batch_size == 'all':
43 | batch_size = len(dset)
44 |
45 | dloader = torch.utils.data.DataLoader(dset,
46 | batch_size=batch_size,
47 | shuffle=shuffle,
48 | num_workers=0,
49 | drop_last=True)
50 | dlen = len(dset)
51 |
52 | return dloader, dlen
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | '''
5 | < ConvBlock >
6 | A small unit block consisting of [convolution layer - normalization layer - non-linearity layer]
7 |
8 | * Parameters
9 | 1. in_dim : Input dimension(channels number)
10 | 2. out_dim : Output dimension(channels number)
11 | 3. k : Kernel size(filter size)
12 | 4. s : stride
13 | 5. p : padding size
14 | 6. norm : If it is true add Instance Normalization layer, otherwise skip this layer
15 | 7. non_linear : You can choose between 'leaky_relu', 'relu', None
16 | '''
17 | class ConvBlock(nn.Module):
18 | def __init__(self, in_dim, out_dim, k=4, s=2, p=1, norm=True, non_linear='leaky_relu'):
19 | super(ConvBlock, self).__init__()
20 | layers = []
21 |
22 | # Convolution Layer
23 | layers += [nn.Conv2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p)]
24 |
25 | # Normalization Layer
26 | if norm is True:
27 | layers += [nn.InstanceNorm2d(out_dim, affine=True)]
28 |
29 | # Non-linearity Layer
30 | if non_linear == 'leaky_relu':
31 | layers += [nn.LeakyReLU(negative_slope=0.2, inplace=True)]
32 | elif non_linear == 'relu':
33 | layers += [nn.ReLU(inplace=True)]
34 |
35 | self.conv_block = nn.Sequential(* layers)
36 |
37 | def forward(self, x):
38 | out = self.conv_block(x)
39 | return out
40 |
41 | '''
42 | < DeconvBlock >
43 | A small unit block consisting of [transposed conv layer - normalization layer - non-linearity layer]
44 |
45 | * Parameters
46 | 1. in_dim : Input dimension(channels number)
47 | 2. out_dim : Output dimension(channels number)
48 | 3. k : Kernel size(filter size)
49 | 4. s : stride
50 | 5. p : padding size
51 | 6. norm : If it is true add Instance Normalization layer, otherwise skip this layer
52 | 7. non_linear : You can choose between 'relu', 'tanh', None
53 | '''
54 | class DeconvBlock(nn.Module):
55 | def __init__(self, in_dim, out_dim, k=4, s=2, p=1, norm=True, non_linear='relu'):
56 | super(DeconvBlock, self).__init__()
57 | layers = []
58 |
59 | # Transpose Convolution Layer
60 | layers += [nn.ConvTranspose2d(in_dim, out_dim, kernel_size=k, stride=s, padding=p)]
61 |
62 | # Normalization Layer
63 | if norm is True:
64 | layers += [nn.InstanceNorm2d(out_dim, affine=True)]
65 |
66 | # Non-Linearity Layer
67 | if non_linear == 'relu':
68 | layers += [nn.ReLU(inplace=True)]
69 | elif non_linear == 'tanh':
70 | layers += [nn.Tanh()]
71 |
72 | self.deconv_block = nn.Sequential(* layers)
73 |
74 | def forward(self, x):
75 | out = self.deconv_block(x)
76 | return out
77 |
78 | '''
79 | < Generator >
80 | U-Net Generator. See https://arxiv.org/abs/1505.04597 figure 1
81 | or https://arxiv.org/pdf/1611.07004 6.1.1 Generator Architectures
82 |
83 | A downsampled activation volume and the upsampled activation volume with the same width and height
84 | make a pair, and they are concatenated when upsampling.
85 | Pairs : (up_1, down_6) (up_2, down_5) (up_3, down_4) (up_4, down_3) (up_5, down_2) (up_6, down_1)
86 | down_7 doesn't have a partner.
87 |
88 | ex) up_1 and down_6 both have size (N, 512, 2, 2) given that the input size is (N, 3, 128, 128).
89 | When forwarding into upsample_2, up_1 and down_6 are concatenated to make (N, 1024, 2, 2), and then
90 | upsample_2 produces (N, 512, 4, 4). That is why upsample_2 has 1024 input channels and 512 output channels.
91 |
92 | Except upsample_1, all the other upsampling blocks do the same thing.
93 | '''
94 | class Generator(nn.Module):
95 | def __init__(self, z_dim=8):
96 | super(Generator, self).__init__()
97 | # Reduce H and W by half at every downsampling
98 | self.downsample_1 = ConvBlock(3 + z_dim, 64, k=4, s=2, p=1, norm=False, non_linear='leaky_relu')
99 | self.downsample_2 = ConvBlock(64, 128, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
100 | self.downsample_3 = ConvBlock(128, 256, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
101 | self.downsample_4 = ConvBlock(256, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
102 | self.downsample_5 = ConvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
103 | self.downsample_6 = ConvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
104 | self.downsample_7 = ConvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='leaky_relu')
105 |
106 |         # Need concatenation when upsampling, see the forward function for details
107 | self.upsample_1 = DeconvBlock(512, 512, k=4, s=2, p=1, norm=True, non_linear='relu')
108 | self.upsample_2 = DeconvBlock(1024, 512, k=4, s=2, p=1, norm=True, non_linear='relu')
109 | self.upsample_3 = DeconvBlock(1024, 512, k=4, s=2, p=1, norm=True, non_linear='relu')
110 | self.upsample_4 = DeconvBlock(1024, 256, k=4, s=2, p=1, norm=True, non_linear='relu')
111 | self.upsample_5 = DeconvBlock(512, 128, k=4, s=2, p=1, norm=True, non_linear='relu')
112 | self.upsample_6 = DeconvBlock(256, 64, k=4, s=2, p=1, norm=True, non_linear='relu')
113 |         self.upsample_7 = DeconvBlock(128, 3, k=4, s=2, p=1, norm=False, non_linear='tanh')
114 |
115 | def forward(self, x, z):
116 | # z : (N, z_dim) -> (N, z_dim, 1, 1) -> (N, z_dim, H, W)
117 | # x_with_z : (N, 3 + z_dim, H, W)
118 | z = z.unsqueeze(dim=2).unsqueeze(dim=3)
119 | z = z.expand(z.size(0), z.size(1), x.size(2), x.size(3))
120 | x_with_z = torch.cat([x, z], dim=1)
121 |
122 | down_1 = self.downsample_1(x_with_z)
123 | down_2 = self.downsample_2(down_1)
124 | down_3 = self.downsample_3(down_2)
125 | down_4 = self.downsample_4(down_3)
126 | down_5 = self.downsample_5(down_4)
127 | down_6 = self.downsample_6(down_5)
128 | down_7 = self.downsample_7(down_6)
129 |
130 | up_1 = self.upsample_1(down_7)
131 | up_2 = self.upsample_2(torch.cat([up_1, down_6], dim=1))
132 | up_3 = self.upsample_3(torch.cat([up_2, down_5], dim=1))
133 | up_4 = self.upsample_4(torch.cat([up_3, down_4], dim=1))
134 | up_5 = self.upsample_5(torch.cat([up_4, down_3], dim=1))
135 | up_6 = self.upsample_6(torch.cat([up_5, down_2], dim=1))
136 | out = self.upsample_7(torch.cat([up_6, down_1], dim=1))
137 |
138 | return out
139 |
140 | '''
141 | < Discriminator >
142 |
143 | PatchGAN discriminator. See https://arxiv.org/pdf/1611.07004 6.1.2 Discriminator architectures.
144 | It uses two discriminators which have different output sizes (different local probabilities).
145 | d_1 : (N, 3, 128, 128) -> (N, 1, 14, 14)
146 | d_2 : (N, 3, 128, 128) -> (N, 1, 30, 30)
147 |
148 | In training, the generator needs to fool both d_1 and d_2, which makes the generator more robust.
149 |
150 | '''
151 | class Discriminator(nn.Module):
152 | def __init__(self):
153 | super(Discriminator, self).__init__()
154 | # Discriminator with last patch (14x14)
155 | # (N, 3, 128, 128) -> (N, 1, 14, 14)
156 | self.d_1 = nn.Sequential(nn.AvgPool2d(kernel_size=3, stride=2, padding=0, count_include_pad=False),
157 | ConvBlock(3, 32, k=4, s=2, p=1, norm=False, non_linear='leaky_relu'),
158 |                                  ConvBlock(32, 64, k=4, s=2, p=1, norm=True, non_linear='leaky_relu'),
159 |                                  ConvBlock(64, 128, k=4, s=1, p=1, norm=True, non_linear='leaky_relu'),
160 | ConvBlock(128, 1, k=4, s=1, p=1, norm=False, non_linear=None))
161 |
162 | # Discriminator with last patch (30x30)
163 | # (N, 3, 128, 128) -> (N, 1, 30, 30)
164 | self.d_2 = nn.Sequential(ConvBlock(3, 64, k=4, s=2, p=1, norm=False, non_linear='leaky_relu'),
165 |                                  ConvBlock(64, 128, k=4, s=2, p=1, norm=True, non_linear='leaky_relu'),
166 |                                  ConvBlock(128, 256, k=4, s=1, p=1, norm=True, non_linear='leaky_relu'),
167 | ConvBlock(256, 1, k=4, s=1, p=1, norm=False, non_linear=None))
168 |
169 | def forward(self, x):
170 | out_1 = self.d_1(x)
171 | out_2 = self.d_2(x)
172 | return (out_1, out_2)
173 |
174 | '''
175 | < ResBlock >
176 |
177 | This residual block is different from the one we usually know, which consists of
178 | [conv - norm - act - conv - norm] and an identity mapping (x -> x) for the shortcut.
179 |
180 | Also spatial size is decreased by half because of AvgPool2d.
181 | '''
182 | class ResBlock(nn.Module):
183 | def __init__(self, in_dim, out_dim):
184 | super(ResBlock, self).__init__()
185 | self.conv = nn.Sequential(nn.InstanceNorm2d(in_dim, affine=True),
186 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
187 | nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
188 | nn.InstanceNorm2d(in_dim, affine=True),
189 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
190 | nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
191 | nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
192 |
193 | self.short_cut = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
194 | nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0))
195 |
196 | def forward(self, x):
197 | out = self.conv(x) + self.short_cut(x)
198 | return out
199 |
200 | '''
201 | < Encoder >
202 |
203 | Output is mu and log(var) for the reparameterization trick used in Variational Auto-Encoders.
204 | Encoding is done in this order.
205 | 1. Use this encoder and get mu and log_var
206 | 2. std = exp(log_var / 2)
207 | 3. random_z ~ N(0, 1)
208 | 4. encoded_z = random_z * std + mu (Reparameterization trick)
209 | '''
210 | class Encoder(nn.Module):
211 | def __init__(self, z_dim=8):
212 | super(Encoder, self).__init__()
213 |
214 | self.conv = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
215 | self.res_blocks = nn.Sequential(ResBlock(64, 128),
216 | ResBlock(128, 192),
217 | ResBlock(192, 256))
218 | self.pool_block = nn.Sequential(nn.LeakyReLU(negative_slope=0.2, inplace=True),
219 | nn.AvgPool2d(kernel_size=8, stride=8, padding=0))
220 |
221 | # Return mu and logvar for reparameterization trick
222 | self.fc_mu = nn.Linear(256, z_dim)
223 | self.fc_logvar = nn.Linear(256, z_dim)
224 |
225 | def forward(self, x):
226 | # (N, 3, 128, 128) -> (N, 64, 64, 64)
227 | out = self.conv(x)
228 | # (N, 64, 64, 64) -> (N, 128, 32, 32) -> (N, 192, 16, 16) -> (N, 256, 8, 8)
229 | out = self.res_blocks(out)
230 | # (N, 256, 8, 8) -> (N, 256, 1, 1)
231 | out = self.pool_block(out)
232 | # (N, 256, 1, 1) -> (N, 256)
233 | out = out.view(x.size(0), -1)
234 |
235 | # (N, 256) -> (N, z_dim) x 2
236 | mu = self.fc_mu(out)
237 | log_var = self.fc_logvar(out)
238 |
239 | return (mu, log_var)
--------------------------------------------------------------------------------
/png/interpolation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eveningglow/BicycleGAN-pytorch/c4419e04052396e2b3815ed50112236a077c04cb/png/interpolation.png
--------------------------------------------------------------------------------
/png/kl_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eveningglow/BicycleGAN-pytorch/c4419e04052396e2b3815ed50112236a077c04cb/png/kl_1.png
--------------------------------------------------------------------------------
/png/kl_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eveningglow/BicycleGAN-pytorch/c4419e04052396e2b3815ed50112236a077c04cb/png/kl_2.png
--------------------------------------------------------------------------------
/png/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eveningglow/BicycleGAN-pytorch/c4419e04052396e2b3815ed50112236a077c04cb/png/model.png
--------------------------------------------------------------------------------
/png/random_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eveningglow/BicycleGAN-pytorch/c4419e04052396e2b3815ed50112236a077c04cb/png/random_sample.png
--------------------------------------------------------------------------------
/png/represent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eveningglow/BicycleGAN-pytorch/c4419e04052396e2b3815ed50112236a077c04cb/png/represent.png
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | '''
2 | If you have any difficulties in following this code,
3 | training step and implementation detail section in README.md might be helpful.
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | import torch.optim as optim
10 | import torchvision
11 |
12 | from dataloader import data_loader
13 | import model
14 | import util
15 |
16 | import os
17 |
18 | '''
19 | < mse_loss >
20 | Calculate mean squared error loss
21 |
22 | * Parameters
23 | score : Output of discriminator
24 | target : 1 for real and 0 for fake
25 | '''
26 | def mse_loss(score, target=1):
27 |
28 |
29 | if target == 1:
30 | label = util.var(torch.ones(score.size()), requires_grad=False)
31 | elif target == 0:
32 | label = util.var(torch.zeros(score.size()), requires_grad=False)
33 |
34 | criterion = nn.MSELoss()
35 | loss = criterion(score, label)
36 |
37 | return loss
38 |
39 | '''
40 | < L1_loss >
41 | Calculate L1 loss
42 |
43 | * Parameters
44 | pred : Output of network
45 | target : Ground truth
46 | '''
47 | def L1_loss(pred, target):
48 | return torch.mean(torch.abs(pred - target))
49 |
50 | def lr_decay_rule(epoch, start_decay=100, lr_decay=100):
51 | decay_rate = 1.0 - (max(0, epoch - start_decay) / float(lr_decay))
52 | return decay_rate
53 |
54 | class Solver():
55 | def __init__(self, root='data/edges2shoes', result_dir='result', weight_dir='weight', load_weight=False,
56 | batch_size=2, test_size=20, test_img_num=5, img_size=128, num_epoch=100, save_every=1000,
57 | lr=0.0002, beta_1=0.5, beta_2=0.999, lambda_kl=0.01, lambda_img=10, lambda_z=0.5, z_dim=8):
58 |
59 |         # Data type (use GPU if available)
60 | self.dtype = torch.cuda.FloatTensor
61 | if torch.cuda.is_available() is False:
62 | self.dtype = torch.FloatTensor
63 |
64 | # Data loader for training
65 | self.dloader, dlen = data_loader(root=root, batch_size=batch_size, shuffle=True,
66 | img_size=img_size, mode='train')
67 |
68 | # Data loader for test
69 | self.t_dloader, _ = data_loader(root=root, batch_size=test_size, shuffle=False,
70 | img_size=img_size, mode='val')
71 |
72 |         # Both D_cVAE and D_cLR consist of two discriminators with different output sizes ((14x14) and (30x30)).
73 |         # In total, we have four discriminators now.
74 | self.D_cVAE = model.Discriminator().type(self.dtype)
75 | self.D_cLR = model.Discriminator().type(self.dtype)
76 | self.G = model.Generator(z_dim=z_dim).type(self.dtype)
77 | self.E = model.Encoder(z_dim=z_dim).type(self.dtype)
78 |
79 | # Optimizers
80 | self.optim_D_cVAE = optim.Adam(self.D_cVAE.parameters(), lr=lr, betas=(beta_1, beta_2))
81 | self.optim_D_cLR = optim.Adam(self.D_cLR.parameters(), lr=lr, betas=(beta_1, beta_2))
82 | self.optim_G = optim.Adam(self.G.parameters(), lr=lr, betas=(beta_1, beta_2))
83 | self.optim_E = optim.Adam(self.E.parameters(), lr=lr, betas=(beta_1, beta_2))
84 |
85 | # fixed random_z for intermediate test
86 | self.fixed_z = util.var(torch.randn(test_size, test_img_num, z_dim))
87 |
88 | # Some hyperparameters
89 | self.z_dim = z_dim
90 | self.lambda_kl = lambda_kl
91 | self.lambda_img = lambda_img
92 | self.lambda_z = lambda_z
93 |
94 | # Extra things
95 | self.result_dir = result_dir
96 | self.weight_dir = weight_dir
97 | self.load_weight = load_weight
98 | self.test_img_num = test_img_num
99 | self.img_size = img_size
100 | self.start_epoch = 0
101 | self.num_epoch = num_epoch
102 | self.save_every = save_every
103 |
104 | '''
105 | < set_train_phase >
106 | Set training phase
107 | '''
108 | def set_train_phase(self):
109 | self.D_cVAE.train()
110 | self.D_cLR.train()
111 | self.G.train()
112 | self.E.train()
113 |
114 | '''
115 | < load_pretrained >
116 | If you want to continue to train, load pretrained weight
117 | '''
118 | def load_pretrained(self):
119 | self.D_cVAE.load_state_dict(torch.load(os.path.join(self.weight_dir, 'D_cVAE.pkl')))
120 | self.D_cLR.load_state_dict(torch.load(os.path.join(self.weight_dir, 'D_cLR.pkl')))
121 | self.G.load_state_dict(torch.load(os.path.join(self.weight_dir, 'G.pkl')))
122 | self.E.load_state_dict(torch.load(os.path.join(self.weight_dir, 'E.pkl')))
123 |
124 |         with open('log.txt', 'r') as log_file:
125 |             line = log_file.readline()
126 |         self.start_epoch = int(line)
127 |
128 | '''
129 | < save_weight >
130 | Save weight
131 | '''
132 | def save_weight(self, epoch=None):
133 | if epoch is None:
134 | d_cVAE_name = 'D_cVAE.pkl'
135 | d_cLR_name = 'D_cLR.pkl'
136 | g_name = 'G.pkl'
137 | e_name = 'E.pkl'
138 | else:
139 | d_cVAE_name = '{epochs}-{name}'.format(epochs=str(epoch), name='D_cVAE.pkl')
140 | d_cLR_name = '{epochs}-{name}'.format(epochs=str(epoch), name='D_cLR.pkl')
141 | g_name = '{epochs}-{name}'.format(epochs=str(epoch), name='G.pkl')
142 | e_name = '{epochs}-{name}'.format(epochs=str(epoch), name='E.pkl')
143 |
144 | torch.save(self.D_cVAE.state_dict(), os.path.join(self.weight_dir, d_cVAE_name))
145 |         torch.save(self.D_cLR.state_dict(), os.path.join(self.weight_dir, d_cLR_name))
146 | torch.save(self.G.state_dict(), os.path.join(self.weight_dir, g_name))
147 | torch.save(self.E.state_dict(), os.path.join(self.weight_dir, e_name))
148 |
149 | '''
150 | < all_zero_grad >
151 | Set all optimizers' grad to zero
152 | '''
153 | def all_zero_grad(self):
154 | self.optim_D_cVAE.zero_grad()
155 | self.optim_D_cLR.zero_grad()
156 | self.optim_G.zero_grad()
157 | self.optim_E.zero_grad()
158 |
159 | '''
160 | < train >
161 | Train the D_cVAE, D_cLR, G and E
162 | '''
163 | def train(self):
164 | if self.load_weight is True:
165 | self.load_pretrained()
166 |
167 | self.set_train_phase()
168 |
169 | for epoch in range(self.start_epoch, self.num_epoch):
170 | for iters, (img, ground_truth) in enumerate(self.dloader):
171 | # img : (2, 3, 128, 128) of domain A / ground_truth : (2, 3, 128, 128) of domain B
172 | img, ground_truth = util.var(img), util.var(ground_truth)
173 |
174 |                 # Separate the batch into cVAE-GAN and cLR-GAN data (assumes batch_size == 2: one sample each)
175 | cVAE_data = {'img' : img[0].unsqueeze(dim=0), 'ground_truth' : ground_truth[0].unsqueeze(dim=0)}
176 | cLR_data = {'img' : img[1].unsqueeze(dim=0), 'ground_truth' : ground_truth[1].unsqueeze(dim=0)}
177 |
178 | ''' ----------------------------- 1. Train D ----------------------------- '''
179 | ############# Step 1. D loss in cVAE-GAN #############
180 |
181 | # Encoded latent vector
182 | mu, log_variance = self.E(cVAE_data['ground_truth'])
183 | std = torch.exp(log_variance / 2)
184 | random_z = util.var(torch.randn(1, self.z_dim))
185 | encoded_z = (random_z * std) + mu
186 |
187 | # Generate fake image
188 | fake_img_cVAE = self.G(cVAE_data['img'], encoded_z)
189 |
190 | # Get scores and loss
191 | real_d_cVAE_1, real_d_cVAE_2 = self.D_cVAE(cVAE_data['ground_truth'])
192 | fake_d_cVAE_1, fake_d_cVAE_2 = self.D_cVAE(fake_img_cVAE)
193 |
194 |                 # mse_loss gives the least-squares GAN (LSGAN) objective: real scores -> 1, fake scores -> 0
195 | D_loss_cVAE_1 = mse_loss(real_d_cVAE_1, 1) + mse_loss(fake_d_cVAE_1, 0)
196 | D_loss_cVAE_2 = mse_loss(real_d_cVAE_2, 1) + mse_loss(fake_d_cVAE_2, 0)
197 |
198 | ############# Step 2. D loss in cLR-GAN #############
199 |
200 | # Random latent vector
201 | random_z = util.var(torch.randn(1, self.z_dim))
202 |
203 | # Generate fake image
204 | fake_img_cLR = self.G(cLR_data['img'], random_z)
205 |
206 | # Get scores and loss
207 | real_d_cLR_1, real_d_cLR_2 = self.D_cLR(cLR_data['ground_truth'])
208 | fake_d_cLR_1, fake_d_cLR_2 = self.D_cLR(fake_img_cLR)
209 |
210 | D_loss_cLR_1 = mse_loss(real_d_cLR_1, 1) + mse_loss(fake_d_cLR_1, 0)
211 | D_loss_cLR_2 = mse_loss(real_d_cLR_2, 1) + mse_loss(fake_d_cLR_2, 0)
212 |
213 | D_loss = D_loss_cVAE_1 + D_loss_cLR_1 + D_loss_cVAE_2 + D_loss_cLR_2
214 |
215 | # Update
216 | self.all_zero_grad()
217 | D_loss.backward()
218 | self.optim_D_cVAE.step()
219 | self.optim_D_cLR.step()
220 |
221 | ''' ----------------------------- 2. Train G & E ----------------------------- '''
222 | ############# Step 1. GAN loss to fool discriminator (cVAE_GAN and cLR_GAN) #############
223 |
224 | # Encoded latent vector
225 | mu, log_variance = self.E(cVAE_data['ground_truth'])
226 | std = torch.exp(log_variance / 2)
227 | random_z = util.var(torch.randn(1, self.z_dim))
228 | encoded_z = (random_z * std) + mu
229 |
230 | # Generate fake image and get adversarial loss
231 | fake_img_cVAE = self.G(cVAE_data['img'], encoded_z)
232 | fake_d_cVAE_1, fake_d_cVAE_2 = self.D_cVAE(fake_img_cVAE)
233 |
234 | GAN_loss_cVAE_1 = mse_loss(fake_d_cVAE_1, 1)
235 | GAN_loss_cVAE_2 = mse_loss(fake_d_cVAE_2, 1)
236 |
237 | # Random latent vector
238 | random_z = util.var(torch.randn(1, self.z_dim))
239 |
240 | # Generate fake image and get adversarial loss
241 | fake_img_cLR = self.G(cLR_data['img'], random_z)
242 | fake_d_cLR_1, fake_d_cLR_2 = self.D_cLR(fake_img_cLR)
243 |
244 | GAN_loss_cLR_1 = mse_loss(fake_d_cLR_1, 1)
245 | GAN_loss_cLR_2 = mse_loss(fake_d_cLR_2, 1)
246 |
247 | G_GAN_loss = GAN_loss_cVAE_1 + GAN_loss_cVAE_2 + GAN_loss_cLR_1 + GAN_loss_cLR_2
248 |
249 | ############# Step 2. KL-divergence with N(0, 1) (cVAE-GAN) #############
250 |
251 | KL_div = self.lambda_kl * torch.sum(0.5 * (mu ** 2 + torch.exp(log_variance) - log_variance - 1))
252 |
253 | ############# Step 3. Reconstruction of ground truth image (|G(A, z) - B|) (cVAE-GAN) #############
254 | img_recon_loss = self.lambda_img * L1_loss(fake_img_cVAE, cVAE_data['ground_truth'])
255 |
256 | EG_loss = G_GAN_loss + KL_div + img_recon_loss
257 | self.all_zero_grad()
258 |                 EG_loss.backward(retain_graph=True)  # retain the graph: fake_img_cLR is reused in step 3 below
259 | self.optim_E.step()
260 | self.optim_G.step()
261 |
262 | ''' ----------------------------- 3. Train ONLY G ----------------------------- '''
263 |                 ############ Step 1. Reconstruction of random latent code (|E(G(A, z)) - z|) (cLR-GAN) ############
264 |
265 |                 # This step should update ONLY G; gradients reaching E are discarded by all_zero_grad() before the next update.
266 | mu_, log_variance_ = self.E(fake_img_cLR)
267 |                 z_recon_loss = L1_loss(mu_, random_z)  # only mu (not a sampled z) is used to reconstruct the latent code
268 |
269 | G_alone_loss = self.lambda_z * z_recon_loss
270 |
271 | self.all_zero_grad()
272 | G_alone_loss.backward()
273 | self.optim_G.step()
274 |
275 |                 with open('log.txt', 'w') as log_file:
276 |                     log_file.write(str(epoch))
277 |
278 | # Print error and save intermediate result image and weight
279 | if iters % self.save_every == 0:
280 | print('[Epoch : %d / Iters : %d] => D_loss : %f / G_GAN_loss : %f / KL_div : %f / img_recon_loss : %f / z_recon_loss : %f'\
281 | %(epoch, iters, D_loss.data[0], G_GAN_loss.data[0], KL_div.data[0], img_recon_loss.data[0], G_alone_loss.data[0]))
282 |
283 | # Save intermediate result image
284 |                     if not os.path.exists(self.result_dir):
285 | os.makedirs(self.result_dir)
286 |
287 | result_img = util.make_img(self.t_dloader, self.G, self.fixed_z,
288 | img_num=self.test_img_num, img_size=self.img_size)
289 |
290 | img_name = '{epoch}_{iters}.png'.format(epoch=epoch, iters=iters)
291 | img_path = os.path.join(self.result_dir, img_name)
292 |
293 | torchvision.utils.save_image(result_img, img_path, nrow=self.test_img_num+1)
294 |
295 | # Save intermediate weight
296 |                     if not os.path.exists(self.weight_dir):
297 | os.makedirs(self.weight_dir)
298 |
299 | self.save_weight()
300 |
301 | # Save weight at the end of every epoch
302 | self.save_weight(epoch=epoch)
303 |
--------------------------------------------------------------------------------
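The E/G update above relies on two standard VAE ingredients: the reparameterization trick (encoded_z = (random_z * std) + mu) and the closed-form KL divergence between N(mu, sigma^2) and N(0, 1) used for KL_div. The following standalone sketch (not a file in this repo) illustrates both; it assumes a recent PyTorch API (torch.randn_like) rather than the repo's Variable-based util.var:

import torch

def reparameterize(mu, log_variance):
    # z = mu + sigma * eps with sigma = exp(log_variance / 2);
    # sampling stays differentiable w.r.t. mu and sigma
    std = torch.exp(log_variance / 2)
    eps = torch.randn_like(std)
    return mu + eps * std

def kl_with_standard_normal(mu, log_variance):
    # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1)
    return torch.sum(0.5 * (mu ** 2 + torch.exp(log_variance) - log_variance - 1))

mu, log_variance = torch.zeros(1, 8), torch.zeros(1, 8)  # i.e. sigma = 1
print(kl_with_standard_normal(mu, log_variance))         # tensor(0.), KL vanishes for N(0, 1)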
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | from dataloader import data_loader
5 | import model
6 | import util
7 |
8 | import os
9 | import numpy as np
10 | import argparse
11 |
12 | '''
13 | < make_interpolation >
14 | Make linear interpolated latent code.
15 |
16 | * Parameters
17 |     n : Number of input images
18 |     img_num : Number of images generated per input image
19 |     z_dim : Dimension of the latent code (8 by default)
20 | '''
21 | def make_interpolation(n=200, img_num=9, z_dim=8):
22 |     if torch.cuda.is_available():
23 | dtype = torch.cuda.FloatTensor
24 | else:
25 | dtype = torch.FloatTensor
26 |
27 | # Make interpolated z
28 |     # np.linspace avoids Python 2 integer division and arange's floating-point length edge cases
29 |     alpha = torch.from_numpy(np.linspace(0, 1, img_num))
30 | interpolated_z = torch.FloatTensor(n, img_num, z_dim).type(dtype)
31 |
32 | for i in range(n):
33 | first_z = torch.randn(1, z_dim)
34 | last_z = torch.randn(1, z_dim)
35 |
36 | for j in range(img_num-1):
37 | interpolated_z[i, j] = (1 - alpha[j]) * first_z + alpha[j] * last_z
38 | interpolated_z[i, img_num-1] = last_z
39 |
40 | return interpolated_z
41 |
42 | '''
43 | < make_z >
44 | Make latent code
45 |
46 | * Parameters
47 |     n : Number of input images
48 |     img_num : Number of images generated per input image
49 |     z_dim : Dimension of the latent code (8 by default)
50 | sample_type : random or interpolation
51 | '''
52 | def make_z(n, img_num, z_dim=8, sample_type='random'):
53 | if sample_type == 'random':
54 |         z = util.var(torch.randn(n, img_num, z_dim))
55 | elif sample_type == 'interpolation':
56 | z = util.var(make_interpolation(n=n, img_num=img_num, z_dim=z_dim))
57 |
58 | return z
59 |
60 |
61 | '''
62 | < make_img >
63 | Generate images.
64 |
65 | * Parameters
66 | dloader : Dataloader
67 | G : Generator
68 | z : Random latent code with size of (N, img_num, z_dim)
69 |     img_size : Image size. Currently only 128 is supported.
70 |     img_num : Number of images generated per input image (inferred from z.size(1), not passed explicitly)
71 | '''
72 | def make_img(dloader, G, z, img_size=128):
73 | if torch.cuda.is_available():
74 | dtype = torch.cuda.FloatTensor
75 | else:
76 | dtype = torch.FloatTensor
77 |
78 | iter_dloader = iter(dloader)
79 |     img, _ = next(iter_dloader)
80 | img_num = z.size(1)
81 |
82 | N = img.size(0)
83 | img = util.var(img.type(dtype))
84 |
85 | result_img = torch.FloatTensor(N * (img_num + 1), 3, img_size, img_size).type(dtype)
86 |
87 | for i in range(N):
88 |         # The leftmost image is the domain A image (edge image)
89 | result_img[i * (img_num + 1)] = img[i].data
90 |
91 |         # Generate img_num images per domain A image
92 | for j in range(img_num):
93 | img_ = img[i].unsqueeze(dim=0)
94 | z_ = z[i, j, :].unsqueeze(dim=0)
95 |
96 | out_img = G(img_, z_)
97 | result_img[i * (img_num + 1) + j + 1] = out_img.data
98 |
99 |
100 |     result_img = result_img / 2 + 0.5  # [-1, 1] -> [0, 1]
101 |
102 | return result_img
103 |
104 | def main(args):
105 | dloader, dlen = data_loader(root=args.root, batch_size='all', shuffle=False,
106 | img_size=128, mode='val')
107 |
108 |     if torch.cuda.is_available():
109 | dtype = torch.cuda.FloatTensor
110 | else:
111 | dtype = torch.FloatTensor
112 |
113 | if args.epoch is not None:
114 | weight_name = '{epoch}-G.pkl'.format(epoch=args.epoch)
115 | else:
116 | weight_name = 'G.pkl'
117 |
118 | weight_path = os.path.join(args.weight_dir, weight_name)
119 | G = model.Generator(z_dim=8).type(dtype)
120 | G.load_state_dict(torch.load(weight_path))
121 | G.eval()
122 |
123 |     if not os.path.exists(args.result_dir):
124 | os.makedirs(args.result_dir)
125 |
126 | # For example, img_name = random_55.png
127 | if args.epoch is None:
128 | args.epoch = 'latest'
129 | img_name = '{type}_{epoch}.png'.format(type=args.sample_type, epoch=args.epoch)
130 | img_path = os.path.join(args.result_dir, img_name)
131 |
132 | # Make latent code and images
133 | z = make_z(n=dlen, img_num=args.img_num, z_dim=8, sample_type=args.sample_type)
134 |
135 | result_img = make_img(dloader, G, z, img_size=128)
136 | torchvision.utils.save_image(result_img, img_path, nrow=args.img_num + 1, padding=4)
137 |
138 | if __name__ == '__main__':
139 | parser = argparse.ArgumentParser()
140 | parser.add_argument('--sample_type', type=str, choices=['random', 'interpolation'], default='random',
141 | help='Type of sampling : \'random\' or \'interpolation\'')
142 | parser.add_argument('--root', type=str, default='data/edges2shoes',
143 | help='Data location')
144 | parser.add_argument('--result_dir', type=str, default='test',
145 |                         help='Output images location')
146 | parser.add_argument('--weight_dir', type=str, default='weight',
147 |                         help='Location of the trained generator weights (.pkl file)')
148 | parser.add_argument('--img_num', type=int, default=5,
149 |                         help='Number of images generated per input image')
150 | parser.add_argument('--epoch', type=int,
151 |                         help='Epoch whose result you want to see. If omitted, the most recently saved weights are used')
152 |
153 | args = parser.parse_args()
154 | main(args)
--------------------------------------------------------------------------------
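make_interpolation above draws two random latent codes and walks a straight line between them, so adjacent generated images vary smoothly. A standalone sketch of the same idea (not part of the repo; assumes a recent PyTorch, and the variable names here are illustrative):

import numpy as np
import torch

z_dim, img_num = 8, 9
first_z, last_z = torch.randn(1, z_dim), torch.randn(1, z_dim)

# alpha runs from 0 to 1 inclusive, so the endpoints are exactly first_z and last_z
alphas = np.linspace(0, 1, img_num)
codes = torch.cat([(1 - float(a)) * first_z + float(a) * last_z for a in alphas], dim=0)

print(codes.shape)                           # torch.Size([9, 8])
print(torch.allclose(codes[0], first_z[0]))  # True
print(torch.allclose(codes[-1], last_z[0]))  # True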
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 |
5 | from solver import Solver
6 |
7 | def main(args):
8 | solver = Solver(root = args.root,
9 | result_dir = args.result_dir,
10 | weight_dir = args.weight_dir,
11 | load_weight = args.load_weight,
12 | batch_size = args.batch_size,
13 | test_size = args.test_size,
14 | test_img_num = args.test_img_num,
15 | img_size = args.img_size,
16 | num_epoch = args.num_epoch,
17 | save_every = args.save_every,
18 | lr = args.lr,
19 | beta_1 = args.beta_1,
20 | beta_2 = args.beta_2,
21 | lambda_kl = args.lambda_kl,
22 | lambda_img = args.lambda_img,
23 | lambda_z = args.lambda_z,
24 | z_dim = args.z_dim)
25 |
26 | solver.train()
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--root', type=str, default='data/edges2shoes',
31 | help='Data location')
32 | parser.add_argument('--result_dir', type=str, default='result_img',
33 | help='Result images location for intermediate check')
34 | parser.add_argument('--weight_dir', type=str, default='weight',
35 | help='Weight location')
36 | parser.add_argument('--batch_size', type=int, default=2,
37 | help='Training batch size')
38 | parser.add_argument('--test_size', type=int, default=20,
39 | help='Test batch size for intermediate check')
40 | parser.add_argument('--test_img_num', type=int, default=5,
41 | help='How many images do you want to generate for intermediate check?')
42 | parser.add_argument('--img_size', type=int, default=128,
43 | help='Image size')
44 | parser.add_argument('--lr', type=float, default=0.0002,
45 | help='Learning rate')
46 | parser.add_argument('--beta_1', type=float, default=0.5,
47 | help='Beta1 for Adam')
48 | parser.add_argument('--beta_2', type=float, default=0.999,
49 | help='Beta2 for Adam')
50 | parser.add_argument('--lambda_kl', type=float, default=0.01,
51 | help='Lambda for KL Divergence')
52 | parser.add_argument('--lambda_img', type=float, default=10,
53 | help='Lambda for image reconstruction')
54 | parser.add_argument('--lambda_z', type=float, default=0.5,
55 | help='Lambda for z reconstruction')
56 | parser.add_argument('--z_dim', type=int, default=8,
57 | help='Dimension of z')
58 | parser.add_argument('--num_epoch', type=int, default=100,
59 | help='Number of epoch')
60 | parser.add_argument('--save_every', type=int, default=1000,
61 | help='How often do you want to see the intermediate result?')
62 | parser.add_argument('--load_weight', action='store_true',
63 | help='Load weight or not')
64 |
65 | args = parser.parse_args()
66 | main(args)
67 |
--------------------------------------------------------------------------------
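Every argument above has a default, so training can start with just the dataset in place. Typical invocations (the paths are examples, not fixed by the repo) might look like:

python train.py --root data/edges2shoes
python train.py --root data/edges2shoes --load_weight   # resume using weight/ and log.txt

Note that the solver splits each batch into one cVAE-GAN sample and one cLR-GAN sample, so --batch_size should stay at its default of 2.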
/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 | '''
5 | < var >
6 | Convert tensor to Variable
7 | '''
8 | def var(tensor, requires_grad=True):
9 | if torch.cuda.is_available():
10 | dtype = torch.cuda.FloatTensor
11 | else:
12 | dtype = torch.FloatTensor
13 |
14 | var = Variable(tensor.type(dtype), requires_grad=requires_grad)
15 |
16 | return var
17 |
18 | '''
19 | < make_img >
20 | Generate images
21 |
22 | * Parameters
23 | dloader : Data loader for test data set
24 | G : Generator
25 |     z : random_z with size (N, img_num, z_dim)
26 |     N : Number of test images / z_dim : Dimension of the latent code (8)
27 |     img_num : Number of images to generate per test image
28 | '''
29 | def make_img(dloader, G, z, img_num=5, img_size=128):
30 | if torch.cuda.is_available():
31 | dtype = torch.cuda.FloatTensor
32 | else:
33 | dtype = torch.FloatTensor
34 |
35 | dloader = iter(dloader)
36 |     img, _ = next(dloader)
37 |
38 | N = img.size(0)
39 | img = var(img.type(dtype))
40 |
41 | result_img = torch.FloatTensor(N * (img_num + 1), 3, img_size, img_size).type(dtype)
42 |
43 | for i in range(N):
44 |         # Put the original image in the leftmost position
45 | result_img[i * (img_num + 1)] = img[i].data
46 |
47 |         # Insert the generated images next to the original image
48 | for j in range(img_num):
49 | img_ = img[i].unsqueeze(dim=0)
50 | z_ = z[i, j, :].unsqueeze(dim=0)
51 |
52 | out_img = G(img_, z_)
53 | result_img[i * (img_num + 1) + j + 1] = out_img.data
54 |
55 |
56 | # [-1, 1] -> [0, 1]
57 | result_img = result_img / 2 + 0.5
58 |
59 | return result_img
--------------------------------------------------------------------------------
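make_img fills result_img row by row: slot i*(img_num+1) holds the i-th input image and the following img_num slots hold its generated samples, which is why callers pass nrow=img_num+1 (or test_img_num+1) to torchvision.utils.save_image so each input occupies exactly one row. A standalone sketch of the index arithmetic, with hypothetical sizes:

N, img_num = 3, 5
for i in range(N):
    row = [i * (img_num + 1)] + [i * (img_num + 1) + j + 1 for j in range(img_num)]
    print(row)
# [0, 1, 2, 3, 4, 5]
# [6, 7, 8, 9, 10, 11]
# [12, 13, 14, 15, 16, 17]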