├── LICENSE
├── README.md
├── assets
│   ├── 1.png
│   ├── 10.png
│   ├── 3.png
│   ├── 5.png
│   ├── 6.png
│   ├── 7.png
│   ├── 8.png
│   ├── 9.png
│   ├── network_architecture.png
│   └── paper_results.png
├── data
│   └── README.md
├── networks.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Hyeonwoo Kang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-Conditional-image-to-image-translation
2 | PyTorch implementation of "Conditional Image-to-Image Translation" [1] (CVPR 2018)
3 | * Hyperparameters that are not specified in the paper were set arbitrarily, since I could not find the supplementary material.
4 |
5 |
6 | ## Usage
7 | ```
8 | python train.py --dataset dataset
9 | ```
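The other training options defined in `train.py` (batch size, image size, number of epochs, learning rates, Adam betas) can be overridden on the command line, for example:
```
python train.py --dataset dataset --batch_size 512 --img_size 64 --train_epoch 100 --lrG 0.0002 --lrD 0.0002
```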
10 |
11 | ### Folder structure
12 | The following shows the basic folder structure.
13 | ```
14 | ├── data
15 | ├── dataset # not included in this repo
16 | ├── trainA
17 | ├── aaa.png
18 | ├── bbb.jpg
19 | └── ...
20 | ├── trainB
21 | ├── ccc.png
22 | ├── ddd.jpg
23 | └── ...
24 | ├── testA
25 | ├── eee.png
26 | ├── fff.jpg
27 | └── ...
28 | └── testB
29 | ├── ggg.png
30 | ├── hhh.jpg
31 | └── ...
32 | ├── train.py # training code
33 | ├── utils.py
34 | ├── networks.py
35 | └── dataset_results # created by train.py as <dataset>_results; results are saved here
36 | ```
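The `dataset` folder itself is not included in this repo. As a rough sketch only, the celebA gender split used for the results below could be prepared as shown next; it assumes the standard CelebA `list_attr_celeba.txt` attribute file and an `img_align_celeba/` image directory, and the paths and output folder name are assumptions, not part of this repo.
```
# Hedged sketch, not part of this repo: split CelebA images into
# data/celebA/trainA (female) and data/celebA/trainB (male) using the
# standard list_attr_celeba.txt layout (count line, attribute-name line,
# then "<filename> +/-1 ... +/-1"). Paths below are assumptions.
import os, shutil

attr_path = 'list_attr_celeba.txt'          # CelebA attribute file (assumed location)
img_dir = 'img_align_celeba'                # aligned CelebA images (assumed location)
out_root = os.path.join('data', 'celebA')

with open(attr_path) as f:
    f.readline()                            # first line: number of images
    names = f.readline().split()            # second line: 40 attribute names
    male_idx = names.index('Male')
    for line in f:
        parts = line.split()
        fname, is_male = parts[0], int(parts[male_idx + 1])
        dst = os.path.join(out_root, 'trainB' if is_male == 1 else 'trainA')
        os.makedirs(dst, exist_ok=True)
        shutil.copy(os.path.join(img_dir, fname), os.path.join(dst, fname))
```
After splitting, move a handful of images from each folder into `testA`/`testB` so the test loaders have data.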
37 |
38 | ## Results
39 | ### paper results
40 | ![paper results](assets/paper_results.png)
41 |
42 | ### celebA gender translation results (100 epoch)
43 |
44 | InputA - InputB - A2B - B2A (this repo)
45 |
46 | (result samples: the numbered images in assets/)
47 |
73 | ## Development Environment
74 | * NVIDIA GTX 1080 ti
75 | * cuda 8.0
76 | * python 3.5.3
77 | * pytorch 0.4.0
78 | * torchvision 0.2.1
79 |
80 | ## Reference
81 | [1] Lin, Jianxin, et al. "Conditional Image-to-Image Translation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.
82 |
83 | (Full paper: http://openaccess.thecvf.com/content_cvpr_2018/papers/Lin_Conditional_Image-to-Image_Translation_CVPR_2018_paper.pdf)
84 |
--------------------------------------------------------------------------------
/assets/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/1.png
--------------------------------------------------------------------------------
/assets/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/10.png
--------------------------------------------------------------------------------
/assets/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/3.png
--------------------------------------------------------------------------------
/assets/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/5.png
--------------------------------------------------------------------------------
/assets/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/6.png
--------------------------------------------------------------------------------
/assets/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/7.png
--------------------------------------------------------------------------------
/assets/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/8.png
--------------------------------------------------------------------------------
/assets/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/9.png
--------------------------------------------------------------------------------
/assets/network_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/network_architecture.png
--------------------------------------------------------------------------------
/assets/paper_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/paper_results.png
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Copy your data in this folder
2 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import torch.nn as nn
3 |
4 | class encoder(nn.Module):
5 | # initializers
6 | def __init__(self, in_nc, nf=32, img_size=64):
7 | super(encoder, self).__init__()
8 | self.input_nc = in_nc
9 | self.nf = nf
10 | self.img_size = img_size
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(in_nc, nf, 4, 2, 1),
13 | nn.LeakyReLU(0.2, True),
14 | nn.Conv2d(nf, nf * 2, 4, 2, 1),
15 | nn.BatchNorm2d(nf * 2),
16 | nn.LeakyReLU(0.2, True),
17 | nn.Conv2d(nf * 2, nf * 4, 4, 2, 1),
18 | nn.BatchNorm2d(nf * 4),
19 | nn.LeakyReLU(0.2, True),
20 | )
21 | self.independent_feature = nn.Sequential(
22 | nn.Conv2d(nf * 4, nf * 8, 4, 2, 1),
23 | )
24 | self.specific_feature = nn.Sequential(
25 | nn.Linear((nf * 4) * (img_size // 8) * (img_size // 8), nf * 8),
26 | nn.LeakyReLU(0.2, True),
27 | nn.Linear(nf * 8, nf * 8),
28 | )
29 |
30 | utils.initialize_weights(self)
31 |
32 | # forward method
33 | def forward(self, input):
34 | x = self.conv(input)
35 | i = self.independent_feature(x)
36 | f = x.view(-1, (self.nf * 4) * (self.img_size // 8) * (self.img_size // 8))
37 | s = self.specific_feature(f)
38 | s = s.unsqueeze(2)
39 | s = s.unsqueeze(3)
40 |
41 | return i, s
42 |
43 |
44 | class decoder(nn.Module):
45 | # initializers
46 | def __init__(self, out_nc, nf=32):
47 | super(decoder, self).__init__()
48 | self.output_nc = out_nc
49 | self.nf = nf
50 | self.deconv = nn.Sequential(
51 | nn.ConvTranspose2d(nf * 8, nf * 4, 4, 2, 1),
52 | nn.ReLU(True),
53 | nn.ConvTranspose2d(nf * 4, nf * 2, 4, 2, 1),
54 | nn.BatchNorm2d(nf * 2),
55 | nn.ReLU(True),
56 | nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1),
57 | nn.BatchNorm2d(nf),
58 | nn.ReLU(True),
59 | nn.ConvTranspose2d(nf, out_nc, 4, 2, 1),
60 | nn.Tanh(),
61 | )
62 |
63 | utils.initialize_weights(self)
64 |
65 | # forward method
66 | def forward(self, input):
67 | x = self.deconv(input)
68 |
69 | return x
70 |
71 |
72 | class discriminator(nn.Module):
73 | # initializers
74 | def __init__(self, in_nc, out_nc, nf=32, img_size=64):
75 | super(discriminator, self).__init__()
76 | self.input_nc = in_nc
77 | self.output_nc = out_nc
78 | self.nf = nf
79 | self.img_size = img_size
80 | self.conv = nn.Sequential(
81 | nn.Conv2d(in_nc, nf, 4, 2, 1),
82 | nn.LeakyReLU(0.2, True),
83 | nn.Conv2d(nf, nf * 2, 4, 2, 1),
84 | nn.BatchNorm2d(nf * 2),
85 | nn.LeakyReLU(0.2, True),
86 | nn.Conv2d(nf * 2, nf * 4, 4, 2, 1),
87 | nn.BatchNorm2d(nf * 4),
88 | nn.LeakyReLU(0.2, True),
89 | nn.Conv2d(nf * 4, nf * 8, 4, 2, 1),
90 | nn.BatchNorm2d(nf * 8),
91 | nn.LeakyReLU(0.2, True),
92 | )
93 | self.fc = nn.Sequential(
94 | nn.Linear((nf * 8) * (img_size // 16) * (img_size // 16), nf * 8),
95 | nn.LeakyReLU(0.2, True),
96 | nn.Linear(nf * 8, out_nc),
97 | nn.Sigmoid(),
98 | )
99 |
100 | utils.initialize_weights(self)
101 |
102 | # forward method
103 | def forward(self, input):
104 | x = self.conv(input)
105 | f = x.view(-1, (self.nf * 8) * (self.img_size // 16) * (self.img_size // 16))
106 | d = self.fc(f)
107 | d = d.unsqueeze(2)
108 | d = d.unsqueeze(3)
109 |
110 | return d
--------------------------------------------------------------------------------
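For reference, a minimal shape check (not part of the repository, using the `networks.py` defaults `nf=32`, `img_size=64`): the encoder returns a domain-independent feature map `i` and a domain-specific vector `s`, `s` is broadcast-added to `i` before decoding, and the discriminator outputs one probability per image.
```
# Hedged sketch: verify the tensor shapes produced by networks.py
import torch
import networks

x = torch.randn(2, 3, 64, 64)
enc = networks.encoder(in_nc=3, nf=32, img_size=64)
dec = networks.decoder(out_nc=3, nf=32)
disc = networks.discriminator(in_nc=3, out_nc=1, nf=32, img_size=64)

i, s = enc(x)
print(i.shape, s.shape)   # torch.Size([2, 256, 4, 4]) torch.Size([2, 256, 1, 1])
y = dec(i + s)            # s broadcasts over the 4x4 spatial grid
print(y.shape)            # torch.Size([2, 3, 64, 64])
print(disc(y).shape)      # torch.Size([2, 1, 1, 1])
```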
/train.py:
--------------------------------------------------------------------------------
1 | import os, time, pickle, argparse, networks, utils, itertools
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import matplotlib.pyplot as plt
6 | from torchvision import transforms
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--dataset', default='dataset', help='dataset')
10 | parser.add_argument('--in_ngc', type=int, default=3, help='input channel for generator')
11 | parser.add_argument('--out_ngc', type=int, default=3, help='output channel for generator')
12 | parser.add_argument('--in_ndc', type=int, default=3, help='input channel for discriminator')
13 | parser.add_argument('--out_ndc', type=int, default=1, help='output channel for discriminator')
14 | parser.add_argument('--batch_size', type=int, default=512, help='batch size')
15 | parser.add_argument('--ngf', type=int, default=64)
16 | parser.add_argument('--ndf', type=int, default=64)
17 | parser.add_argument('--nb', type=int, default=8, help='the number of resnet blocks for generator (unused in this implementation)')
18 | parser.add_argument('--img_size', type=int, default=64, help='input image size')
19 | parser.add_argument('--train_epoch', type=int, default=100)
20 | parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
21 | parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
22 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
23 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
24 | args = parser.parse_args()
25 |
26 | print('------------ Options -------------')
27 | for k, v in sorted(vars(args).items()):
28 | print('%s: %s' % (str(k), str(v)))
29 | print('-------------- End ----------------')
30 |
31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 | if torch.backends.cudnn.enabled:
33 | torch.backends.cudnn.benchmark = True
34 |
35 | # results save path
36 | if not os.path.isdir(os.path.join(args.dataset + '_results', 'img')):
37 |     os.makedirs(os.path.join(args.dataset + '_results', 'img'))
38 | if not os.path.isdir(os.path.join(args.dataset + '_results', 'model')):
39 |     os.makedirs(os.path.join(args.dataset + '_results', 'model'))
40 |
41 | # data_loader
42 | transform = transforms.Compose([
43 | transforms.Resize((args.img_size, args.img_size)),
44 | transforms.ToTensor(),
45 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
46 | ])
47 | train_loader_A = utils.data_load(os.path.join('data', args.dataset), 'trainA', transform, args.batch_size, shuffle=True, drop_last=True)
48 | train_loader_B = utils.data_load(os.path.join('data', args.dataset), 'trainB', transform, args.batch_size, shuffle=True, drop_last=True)
49 | test_loader_A = utils.data_load(os.path.join('data', args.dataset), 'testA', transform, 1, shuffle=True, drop_last=True)
50 | test_loader_B = utils.data_load(os.path.join('data', args.dataset), 'testB', transform, 1, shuffle=True, drop_last=True)
51 |
52 | # network
53 | En_A = networks.encoder(in_nc=args.in_ngc, nf=args.ngf, img_size=args.img_size).to(device)
54 | En_B = networks.encoder(in_nc=args.in_ngc, nf=args.ngf, img_size=args.img_size).to(device)
55 | De_A = networks.decoder(out_nc=args.out_ngc, nf=args.ngf).to(device)
56 | De_B = networks.decoder(out_nc=args.out_ngc, nf=args.ngf).to(device)
57 | Disc_A = networks.discriminator(in_nc=args.in_ndc, out_nc=args.out_ndc, nf=args.ndf, img_size=args.img_size).to(device)
58 | Disc_B = networks.discriminator(in_nc=args.in_ndc, out_nc=args.out_ndc, nf=args.ndf, img_size=args.img_size).to(device)
59 | En_A.train()
60 | En_B.train()
61 | De_A.train()
62 | De_B.train()
63 | Disc_A.train()
64 | Disc_B.train()
65 | print('---------- Networks initialized -------------')
66 | utils.print_network(En_A)
67 | utils.print_network(En_B)
68 | utils.print_network(De_A)
69 | utils.print_network(De_B)
70 | utils.print_network(Disc_A)
71 | utils.print_network(Disc_B)
72 | print('-----------------------------------------------')
73 |
74 | # loss
75 | BCE_loss = nn.BCELoss().to(device)
76 | L1_loss = nn.L1Loss().to(device)
77 |
78 | # Adam optimizer
79 | Gen_optimizer = optim.Adam(itertools.chain(En_A.parameters(), De_A.parameters(), En_B.parameters(), De_B.parameters()), lr=args.lrG, betas=(args.beta1, args.beta2))
80 | Disc_A_optimizer = optim.Adam(Disc_A.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
81 | Disc_B_optimizer = optim.Adam(Disc_B.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
82 |
83 | train_hist = {}
84 | train_hist['Disc_A_loss'] = []
85 | train_hist['Disc_B_loss'] = []
86 | train_hist['Gen_loss'] = []
87 | train_hist['per_epoch_time'] = []
88 | train_hist['total_time'] = []
89 | print('training start!')
90 | start_time = time.time()
91 | real = torch.ones(args.batch_size, 1, 1, 1).to(device)
92 | fake = torch.zeros(args.batch_size, 1, 1, 1).to(device)
93 | for epoch in range(args.train_epoch):
94 | epoch_start_time = time.time()
95 | En_A.train()
96 | En_B.train()
97 | De_A.train()
98 | De_B.train()
99 | Disc_A_losses = []
100 | Disc_B_losses = []
101 | Gen_losses = []
102 | iter = 0
103 | for (A, _), (B, _) in zip(train_loader_A, train_loader_B):
104 | A, B = A.to(device), B.to(device)
105 |
106 | # train Disc_A & Disc_B
107 | # Disc real loss
108 | Disc_A_real = Disc_A(A)
109 | Disc_A_real_loss = BCE_loss(Disc_A_real, real)
110 |
111 | Disc_B_real = Disc_B(B)
112 | Disc_B_real_loss = BCE_loss(Disc_B_real, real)
113 |
114 | # Disc fake loss
115 | in_A, sp_A = En_A(A)
116 | in_B, sp_B = En_B(B)
117 |
118 | # De_A == B2A decoder, De_B == A2B decoder
119 | B2A = De_A(in_B + sp_A)
120 | A2B = De_B(in_A + sp_B)
121 |
122 | Disc_A_fake = Disc_A(B2A)
123 | Disc_A_fake_loss = BCE_loss(Disc_A_fake, fake)
124 |
125 | Disc_B_fake = Disc_B(A2B)
126 | Disc_B_fake_loss = BCE_loss(Disc_B_fake, fake)
127 |
128 | Disc_A_loss = Disc_A_real_loss + Disc_A_fake_loss
129 | Disc_B_loss = Disc_B_real_loss + Disc_B_fake_loss
130 |
131 | Disc_A_optimizer.zero_grad()
132 | Disc_A_loss.backward(retain_graph=True)
133 | Disc_A_optimizer.step()
134 |
135 | Disc_B_optimizer.zero_grad()
136 | Disc_B_loss.backward(retain_graph=True)
137 | Disc_B_optimizer.step()
138 |
139 | train_hist['Disc_A_loss'].append(Disc_A_loss.item())
140 | train_hist['Disc_B_loss'].append(Disc_B_loss.item())
141 | Disc_A_losses.append(Disc_A_loss.item())
142 | Disc_B_losses.append(Disc_B_loss.item())
143 |
144 | # train Gen
145 | # Gen adversarial loss
146 | in_A, sp_A = En_A(A)
147 | in_B, sp_B = En_B(B)
148 |
149 | B2A = De_A(in_B + sp_A)
150 | A2B = De_B(in_A + sp_B)
151 |
152 |         Disc_A_fake = Disc_A(B2A)  # re-evaluate the updated discriminator on the new fake
153 |         Gen_A_fake_loss = BCE_loss(Disc_A_fake, real)
154 |
155 | Disc_B_fake = Disc_B(A2B)
156 | Gen_B_fake_loss = BCE_loss(Disc_B_fake, real)
157 |
158 | # Gen Dual loss
159 | in_A_hat, sp_B_hat = En_B(A2B)
160 | in_B_hat, sp_A_hat = En_A(B2A)
161 |
162 | A_hat = De_A(in_A_hat + sp_A)
163 | B_hat = De_B(in_B_hat + sp_B)
164 |
165 | Gen_gan_loss = Gen_A_fake_loss + Gen_B_fake_loss
166 | Gen_dual_loss = L1_loss(A_hat, A.detach()) ** 2 + L1_loss(B_hat, B.detach()) ** 2
167 | Gen_in_loss = L1_loss(in_A_hat, in_A.detach()) ** 2 + L1_loss(in_B_hat, in_B.detach()) ** 2
168 | Gen_sp_loss = L1_loss(sp_A_hat, sp_A.detach()) ** 2 + L1_loss(sp_B_hat, sp_B.detach()) ** 2
169 |
170 |         Gen_loss = Gen_gan_loss + Gen_dual_loss + Gen_in_loss + Gen_sp_loss
171 |
172 | Gen_optimizer.zero_grad()
173 | Gen_loss.backward()
174 | Gen_optimizer.step()
175 |
176 | train_hist['Gen_loss'].append(Gen_loss.item())
177 | Gen_losses.append(Gen_loss.item())
178 |
179 | iter += 1
180 |
181 | per_epoch_time = time.time() - epoch_start_time
182 | train_hist['per_epoch_time'].append(per_epoch_time)
183 | print(
184 | '[%d/%d] - time: %.2f, Disc A loss: %.3f, Disc B loss: %.3f, Gen loss: %.3f' % (
185 | (epoch + 1), args.train_epoch, per_epoch_time, torch.mean(torch.FloatTensor(Disc_A_losses)),
186 | torch.mean(torch.FloatTensor(Disc_B_losses)), torch.mean(torch.FloatTensor(Gen_losses)),))
187 |
188 |
189 | with torch.no_grad():
190 | En_A.eval()
191 | En_B.eval()
192 | De_A.eval()
193 | De_B.eval()
194 | n = 0
195 | for (A, _), (B, _) in zip(test_loader_A, test_loader_B):
196 | A, B = A.to(device), B.to(device)
197 |
198 | in_A, sp_A = En_A(A)
199 | in_B, sp_B = En_B(B)
200 |
201 | B2A = De_A(in_B + sp_A)
202 | A2B = De_B(in_A + sp_B)
203 |
204 | result = torch.cat((A[0], B[0], A2B[0], B2A[0]), 2)
205 |             path = os.path.join(args.dataset + '_results', 'img', str(epoch+1) + '_epoch_' + args.dataset + '_' + str(n + 1) + '.png')
206 | plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
207 | n += 1
208 |
209 |         torch.save(En_A.state_dict(), os.path.join(args.dataset + '_results', 'model', 'En_A_param_latest.pkl'))
210 |         torch.save(En_B.state_dict(), os.path.join(args.dataset + '_results', 'model', 'En_B_param_latest.pkl'))
211 |         torch.save(De_A.state_dict(), os.path.join(args.dataset + '_results', 'model', 'De_A_param_latest.pkl'))
212 |         torch.save(De_B.state_dict(), os.path.join(args.dataset + '_results', 'model', 'De_B_param_latest.pkl'))
213 |         torch.save(Disc_A.state_dict(), os.path.join(args.dataset + '_results', 'model', 'Disc_A_param_latest.pkl'))
214 |         torch.save(Disc_B.state_dict(), os.path.join(args.dataset + '_results', 'model', 'Disc_B_param_latest.pkl'))
215 |
216 |
217 | if (epoch+1) % 50 == 0:
218 |             torch.save(En_A.state_dict(),
219 |                 os.path.join(args.dataset + '_results', 'model', 'En_A_param_' + str(epoch+1) + '.pkl'))
220 |             torch.save(En_B.state_dict(),
221 |                 os.path.join(args.dataset + '_results', 'model', 'En_B_param_' + str(epoch+1) + '.pkl'))
222 |             torch.save(De_A.state_dict(),
223 |                 os.path.join(args.dataset + '_results', 'model', 'De_A_param_' + str(epoch+1) + '.pkl'))
224 |             torch.save(De_B.state_dict(),
225 |                 os.path.join(args.dataset + '_results', 'model', 'De_B_param_' + str(epoch+1) + '.pkl'))
226 |             torch.save(Disc_A.state_dict(),
227 |                 os.path.join(args.dataset + '_results', 'model', 'Disc_A_param_' + str(epoch+1) + '.pkl'))
228 |             torch.save(Disc_B.state_dict(),
229 |                 os.path.join(args.dataset + '_results', 'model', 'Disc_B_param_' + str(epoch+1) + '.pkl'))
230 |
231 | total_time = time.time() - start_time
232 | train_hist['total_time'].append(total_time)
233 |
234 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_time'])), args.train_epoch, total_time))
235 | print("Training finish!... save training results")
236 | with open(os.path.join(args.dataset + '_results', 'train_hist.pkl'), 'wb') as f:
237 | pickle.dump(train_hist, f)
238 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | import torch.nn as nn
3 | from torchvision import datasets
4 |
5 | def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=False):
6 | dset = datasets.ImageFolder(path, transform)
7 |     ind = dset.class_to_idx[subfolder]  # class index assigned to the requested subfolder (e.g. trainA)
8 |
9 | n = 0
10 |     for i in range(dset.__len__()):  # drop every image that does not belong to that subfolder
11 | if ind != dset.imgs[n][1]:
12 | del dset.imgs[n]
13 | n -= 1
14 |
15 | n += 1
16 |
17 | return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
18 |
19 | def print_network(net):
20 | num_params = 0
21 | for param in net.parameters():
22 | num_params += param.numel()
23 | print(net)
24 | print('Total number of parameters: %d' % num_params)
25 |
26 | def initialize_weights(net):
27 | for m in net.modules():
28 | if isinstance(m, nn.Conv2d):
29 | m.weight.data.normal_(0, 0.02)
30 | m.bias.data.zero_()
31 | elif isinstance(m, nn.ConvTranspose2d):
32 | m.weight.data.normal_(0, 0.02)
33 | m.bias.data.zero_()
34 | elif isinstance(m, nn.Linear):
35 | m.weight.data.normal_(0, 0.02)
36 | m.bias.data.zero_()
37 | elif isinstance(m, nn.BatchNorm2d):
38 | m.weight.data.fill_(1)
39 | m.bias.data.zero_()
--------------------------------------------------------------------------------