├── LICENSE ├── README.md ├── assets ├── 1.png ├── 10.png ├── 3.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── network_architecture.png └── paper_results.png ├── data └── README.md ├── networks.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Hyeonwoo Kang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-Conditional-image-to-image-translation 2 | PyTorch implementation of Conditional Image-to-Image Translation [1] (CVPR 2018) 3 | * Hyperparameters that are not specified in the paper were set arbitrarily (the supplementary material was not available). 4 | 5 | 6 | ## Usage 7 | ``` 8 | python train.py --dataset dataset 9 | ``` 10 |
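All training hyperparameters are exposed as command-line flags in train.py; the command above only sets the dataset folder. A typical invocation that overrides the most relevant options might look like this (the dataset name is just an example, and the default batch size of 512 may need to be lowered to fit GPU memory):
```
python train.py --dataset celebA --img_size 64 --batch_size 128 --train_epoch 100 --lrG 0.0002 --lrD 0.0002
```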
11 | ### Folder structure 12 | The following shows the basic folder structure. 13 | ``` 14 | ├── data 15 | ├── dataset # not included in this repo 16 | ├── trainA 17 | ├── aaa.png 18 | ├── bbb.jpg 19 | └── ... 20 | ├── trainB 21 | ├── ccc.png 22 | ├── ddd.jpg 23 | └── ... 24 | ├── testA 25 | ├── eee.png 26 | ├── fff.jpg 27 | └── ... 28 | └── testB 29 | ├── ggg.png 30 | ├── hhh.jpg 31 | └── ... 32 | ├── train.py # training code 33 | ├── utils.py 34 | ├── networks.py 35 | └── <dataset>_results # results are saved here (see the output layout below) 36 | ``` 37 | 38 | ## Results 39 | ### paper results 40 | (figure: assets/paper_results.png) 41 | 42 | ### celebA gender translation results (100 epoch) 43 | InputA - InputB - A2B - B2A (this repo) 44 | (result images: assets/1.png – assets/10.png) 45 |
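train.py creates its output folders next to the repository root rather than under data/. For `--dataset dataset`, the layout is roughly the following — file names follow the saving code in train.py, and the epoch/index numbers are only examples:
```
dataset_results
├── img
│   └── 1_epoch_dataset_1.png     # InputA | InputB | A2B | B2A test strips, saved every epoch
├── model
│   ├── En_A_param_latest.pkl     # encoder/decoder/discriminator state_dicts, overwritten every epoch
│   ├── ...
│   └── Disc_B_param_50.pkl       # an extra snapshot every 50 epochs
└── train_hist.pkl                # loss history, written after training finishes
```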
72 | 73 | ## Development Environment 74 | * NVIDIA GTX 1080 ti 75 | * cuda 8.0 76 | * python 3.5.3 77 | * pytorch 0.4.0 78 | * torchvision 0.2.1 79 | 80 | ## Reference 81 | [1] Lin, Jianxin, et al. "Conditional image-to-image translation." The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)(July 2018). 2018. 82 | 83 | (Full paper: http://openaccess.thecvf.com/content_cvpr_2018/papers/Lin_Conditional_Image-to-Image_Translation_CVPR_2018_paper.pdf) 84 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/1.png -------------------------------------------------------------------------------- /assets/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/10.png -------------------------------------------------------------------------------- /assets/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/3.png -------------------------------------------------------------------------------- /assets/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/5.png -------------------------------------------------------------------------------- /assets/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/6.png -------------------------------------------------------------------------------- /assets/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/7.png -------------------------------------------------------------------------------- /assets/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/8.png -------------------------------------------------------------------------------- /assets/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/9.png -------------------------------------------------------------------------------- /assets/network_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/network_architecture.png -------------------------------------------------------------------------------- /assets/paper_results.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/znxlwm/pytorch-Conditional-image-to-image-translation/cdcb2e6a76a507167e174647beaf2992d3d89d95/assets/paper_results.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Copy your data in this folder 2 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import torch.nn as nn 3 | 4 | class encoder(nn.Module): 5 | # initializers 6 | def __init__(self, in_nc, nf=32, img_size=64): 7 | super(encoder, self).__init__() 8 | self.input_nc = in_nc 9 | self.nf = nf 10 | self.img_size = img_size 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(in_nc, nf, 4, 2, 1), 13 | nn.LeakyReLU(0.2, True), 14 | nn.Conv2d(nf, nf * 2, 4, 2, 1), 15 | nn.BatchNorm2d(nf * 2), 16 | nn.LeakyReLU(0.2, True), 17 | nn.Conv2d(nf * 2, nf * 4, 4, 2, 1), 18 | nn.BatchNorm2d(nf * 4), 19 | nn.LeakyReLU(0.2, True), 20 | ) 21 | self.independent_feature = nn.Sequential( 22 | nn.Conv2d(nf * 4, nf * 8, 4, 2, 1), 23 | ) 24 | self.specific_feature = nn.Sequential( 25 | nn.Linear((nf * 4) * (img_size // 8) * (img_size // 8), nf * 8), 26 | nn.LeakyReLU(0.2, True), 27 | nn.Linear(nf * 8, nf * 8), 28 | ) 29 | 30 | utils.initialize_weights(self) 31 | 32 | # forward method 33 | def forward(self, input): 34 | x = self.conv(input) 35 | i = self.independent_feature(x) 36 | f = x.view(-1, (self.nf * 4) * (self.img_size // 8) * (self.img_size // 8)) 37 | s = self.specific_feature(f) 38 | s = s.unsqueeze(2) 39 | s = s.unsqueeze(3) 40 | 41 | return i, s 42 | 43 | 44 | class decoder(nn.Module): 45 | # initializers 46 | def __init__(self, out_nc, nf=32): 47 | super(decoder, self).__init__() 48 | self.output_nc = out_nc 49 | self.nf = nf 50 | self.deconv = nn.Sequential( 51 | nn.ConvTranspose2d(nf * 8, nf * 4, 4, 2, 1), 52 | nn.ReLU(True), 53 | nn.ConvTranspose2d(nf * 4, nf * 2, 4, 2, 1), 54 | nn.BatchNorm2d(nf * 2), 55 | nn.ReLU(True), 56 | nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1), 57 | nn.BatchNorm2d(nf), 58 | nn.ReLU(True), 59 | nn.ConvTranspose2d(nf, out_nc, 4, 2, 1), 60 | nn.Tanh(), 61 | ) 62 | 63 | utils.initialize_weights(self) 64 | 65 | # forward method 66 | def forward(self, input): 67 | x = self.deconv(input) 68 | 69 | return x 70 | 71 | 72 | class discriminator(nn.Module): 73 | # initializers 74 | def __init__(self, in_nc, out_nc, nf=32, img_size=64): 75 | super(discriminator, self).__init__() 76 | self.input_nc = in_nc 77 | self.output_nc = out_nc 78 | self.nf = nf 79 | self.img_size = img_size 80 | self.conv = nn.Sequential( 81 | nn.Conv2d(in_nc, nf, 4, 2, 1), 82 | nn.LeakyReLU(0.2, True), 83 | nn.Conv2d(nf, nf * 2, 4, 2, 1), 84 | nn.BatchNorm2d(nf * 2), 85 | nn.LeakyReLU(0.2, True), 86 | nn.Conv2d(nf * 2, nf * 4, 4, 2, 1), 87 | nn.BatchNorm2d(nf * 4), 88 | nn.LeakyReLU(0.2, True), 89 | nn.Conv2d(nf * 4, nf * 8, 4, 2, 1), 90 | nn.BatchNorm2d(nf * 8), 91 | nn.LeakyReLU(0.2, True), 92 | ) 93 | self.fc = nn.Sequential( 94 | nn.Linear((nf * 8) * (img_size // 16) * (img_size // 16), nf * 8), 95 | nn.LeakyReLU(0.2, True), 96 | nn.Linear(nf * 8, out_nc), 97 | nn.Sigmoid(), 98 | ) 99 | 100 | utils.initialize_weights(self) 101 | 102 | # forward method 103 | def forward(self, input): 104 | x = self.conv(input) 105 | f = x.view(-1, (self.nf * 8) * (self.img_size 
// 16) * (self.img_size // 16)) 106 | d = self.fc(f) 107 | d = d.unsqueeze(2) 108 | d = d.unsqueeze(3) 109 | 110 | return d -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, time, pickle, argparse, networks, utils, itertools 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import matplotlib.pyplot as plt 6 | from torchvision import transforms 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dataset', default='dataset', help='dataset') 10 | parser.add_argument('--in_ngc', type=int, default=3, help='input channel for generator') 11 | parser.add_argument('--out_ngc', type=int, default=3, help='output channel for generator') 12 | parser.add_argument('--in_ndc', type=int, default=3, help='input channel for discriminator') 13 | parser.add_argument('--out_ndc', type=int, default=1, help='output channel for discriminator') 14 | parser.add_argument('--batch_size', type=int, default=512, help='batch size') 15 | parser.add_argument('--ngf', type=int, default=64) 16 | parser.add_argument('--ndf', type=int, default=64) 17 | parser.add_argument('--nb', type=int, default=8, help='the number of resnet block layers for generator') 18 | parser.add_argument('--img_size', type=int, default=64, help='input image size') 19 | parser.add_argument('--train_epoch', type=int, default=100) 20 | parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002') 21 | parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002') 22 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') 23 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') 24 | args = parser.parse_args() 25 | 26 | print('------------ Options -------------') 27 | for k, v in sorted(vars(args).items()): 28 | print('%s: %s' % (str(k), str(v))) 29 | print('-------------- End ----------------') 30 | 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | if torch.backends.cudnn.enabled: 33 | torch.backends.cudnn.benchmark = True 34 | 35 | # results save path 36 | if not os.path.isdir(os.path.join(args.dataset + '_results', 'img')): 37 | os.makedirs(os.path.join(args.dataset + '_results', 'img')) 38 | if not os.path.isdir(os.path.join(args.dataset + '_results', 'model')): 39 | os.makedirs(os.path.join(args.dataset + '_results', 'model')) 40 | 41 | # data_loader 42 | transform = transforms.Compose([ 43 | transforms.Resize((args.img_size, args.img_size)), 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 46 | ]) 47 | train_loader_A = utils.data_load(os.path.join('data', args.dataset), 'trainA', transform, args.batch_size, shuffle=True, drop_last=True) 48 | train_loader_B = utils.data_load(os.path.join('data', args.dataset), 'trainB', transform, args.batch_size, shuffle=True, drop_last=True) 49 | test_loader_A = utils.data_load(os.path.join('data', args.dataset), 'testA', transform, 1, shuffle=True, drop_last=True) 50 | test_loader_B = utils.data_load(os.path.join('data', args.dataset), 'testB', transform, 1, shuffle=True, drop_last=True) 51 | 52 | # network 53 | En_A = networks.encoder(in_nc=args.in_ngc, nf=args.ngf, img_size=args.img_size).to(device) 54 | En_B = networks.encoder(in_nc=args.in_ngc, nf=args.ngf, img_size=args.img_size).to(device) 55 | De_A = networks.decoder(out_nc=args.out_ngc, nf=args.ngf).to(device) 56 | De_B = networks.decoder(out_nc=args.out_ngc, nf=args.ngf).to(device) 57 | Disc_A = networks.discriminator(in_nc=args.in_ndc, out_nc=args.out_ndc, nf=args.ndf, img_size=args.img_size).to(device) 58 | Disc_B = networks.discriminator(in_nc=args.in_ndc, out_nc=args.out_ndc, nf=args.ndf, img_size=args.img_size).to(device) 59 | En_A.train() 60 | En_B.train() 61 | De_A.train() 62 | De_B.train() 63 | Disc_A.train() 64 | Disc_B.train() 65 | print('---------- Networks initialized -------------') 66 | utils.print_network(En_A) 67 | utils.print_network(En_B) 68 | utils.print_network(De_A) 69 | utils.print_network(De_B) 70 | utils.print_network(Disc_A) 71 | utils.print_network(Disc_B) 72 | print('-----------------------------------------------') 73 | 74 | # loss 75 | BCE_loss = nn.BCELoss().to(device) 76 | L1_loss = nn.L1Loss().to(device) 77 | 78 | # Adam optimizer 79 | Gen_optimizer = optim.Adam(itertools.chain(En_A.parameters(), De_A.parameters(), En_B.parameters(), De_B.parameters()), lr=args.lrG, betas=(args.beta1, args.beta2)) 80 | Disc_A_optimizer = optim.Adam(Disc_A.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 81 | Disc_B_optimizer = optim.Adam(Disc_B.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) 82 | 83 | train_hist = {} 84 | train_hist['Disc_A_loss'] = [] 85 | train_hist['Disc_B_loss'] = [] 86 | train_hist['Gen_loss'] = [] 87 | train_hist['per_epoch_time'] = [] 88 | train_hist['total_time'] = [] 89 | print('training start!') 90 | start_time = time.time() 91 | real = torch.ones(args.batch_size, 1, 1, 1).to(device)  # adversarial targets; drop_last=True keeps every batch at args.batch_size 92 | fake = torch.zeros(args.batch_size, 1, 1, 1).to(device) 93 | for epoch in range(args.train_epoch): 94 | epoch_start_time = time.time() 95 | En_A.train() 96 | En_B.train() 97 | De_A.train() 98 | De_B.train() 99 | Disc_A_losses = [] 100 | Disc_B_losses = [] 101 | Gen_losses = [] 102 | iter = 0 103 | for (A, _), (B, _) in zip(train_loader_A, train_loader_B): 104 | A, B = A.to(device), B.to(device) 105 | 106 | # train Disc_A & Disc_B 107 | # Disc real loss 108 | Disc_A_real = Disc_A(A) 109 | Disc_A_real_loss = BCE_loss(Disc_A_real, real) 110 | 111 | Disc_B_real = Disc_B(B) 112 | Disc_B_real_loss = BCE_loss(Disc_B_real, real) 113 | 114 | # Disc fake loss 115 | in_A, sp_A = En_A(A) 116 | in_B, sp_B = En_B(B) 117 | 118 | # De_A == B2A decoder, De_B == A2B decoder 119 | B2A = De_A(in_B + sp_A) 120 | A2B = De_B(in_A + sp_B) 121 | 122 | Disc_A_fake = Disc_A(B2A) 123 | Disc_A_fake_loss = BCE_loss(Disc_A_fake, fake) 124 | 125 | Disc_B_fake = Disc_B(A2B) 126 | Disc_B_fake_loss = BCE_loss(Disc_B_fake, fake) 127 | 128 | Disc_A_loss = Disc_A_real_loss + Disc_A_fake_loss 129 | Disc_B_loss = Disc_B_real_loss + Disc_B_fake_loss 130 | 131 | Disc_A_optimizer.zero_grad() 132 | Disc_A_loss.backward(retain_graph=True)  # the fake images share a graph with the Disc_B update below 133 | Disc_A_optimizer.step() 134 | 135 | Disc_B_optimizer.zero_grad() 136 | Disc_B_loss.backward(retain_graph=True) 137 | Disc_B_optimizer.step() 138 | 139 | train_hist['Disc_A_loss'].append(Disc_A_loss.item()) 140 | train_hist['Disc_B_loss'].append(Disc_B_loss.item()) 141 | Disc_A_losses.append(Disc_A_loss.item()) 142 | Disc_B_losses.append(Disc_B_loss.item()) 143 | 144 | # train Gen 145 | # Gen adversarial loss 146 | in_A, sp_A = En_A(A) 147 | in_B, sp_B = En_B(B) 148 | 149 | B2A = De_A(in_B + sp_A) 150 | A2B = De_B(in_A + sp_B) 151 | 152 | Disc_A_fake = Disc_A(B2A) 153 | Gen_A_fake_loss = BCE_loss(Disc_A_fake, real) 154 | 155 | Disc_B_fake = Disc_B(A2B) 156 | Gen_B_fake_loss = BCE_loss(Disc_B_fake, real) 157 | 158 | # Gen Dual loss 159 | in_A_hat, sp_B_hat = En_B(A2B) 160 | in_B_hat, sp_A_hat = En_A(B2A) 161 | 162 | A_hat = De_A(in_A_hat + sp_A) 163 | B_hat = De_B(in_B_hat + sp_B) 164 | 165 | Gen_gan_loss = Gen_A_fake_loss + Gen_B_fake_loss 166 | Gen_dual_loss = L1_loss(A_hat, A.detach()) ** 2 + L1_loss(B_hat, B.detach()) ** 2 167 | Gen_in_loss = L1_loss(in_A_hat, in_A.detach()) ** 2 + L1_loss(in_B_hat, in_B.detach()) ** 2 168 | Gen_sp_loss = L1_loss(sp_A_hat, sp_A.detach()) ** 2 + L1_loss(sp_B_hat, sp_B.detach()) ** 2 169 | 170 | Gen_loss = Gen_gan_loss + Gen_dual_loss + Gen_in_loss + Gen_sp_loss 171 | 172 | Gen_optimizer.zero_grad() 173 | Gen_loss.backward() 174 | Gen_optimizer.step() 175 | 176 | train_hist['Gen_loss'].append(Gen_loss.item()) 177 | Gen_losses.append(Gen_loss.item()) 178 | 179 | iter += 1 180 | 181 | per_epoch_time = time.time() - epoch_start_time 182 | train_hist['per_epoch_time'].append(per_epoch_time) 183 | print( 184 | '[%d/%d] - time: %.2f, Disc A loss: %.3f, Disc B loss: %.3f, Gen loss: %.3f' % ( 185 | (epoch + 1), args.train_epoch, per_epoch_time, torch.mean(torch.FloatTensor(Disc_A_losses)), 186 | torch.mean(torch.FloatTensor(Disc_B_losses)), torch.mean(torch.FloatTensor(Gen_losses)),)) 187 | 188 | 189 | with torch.no_grad(): 190 | En_A.eval() 191 | En_B.eval() 192 | De_A.eval() 193 | De_B.eval() 194 | n = 0 195 | for (A, _), (B, _) in zip(test_loader_A, test_loader_B): 196 | A, B = A.to(device), B.to(device) 197 | 198 | in_A, sp_A = En_A(A) 199 | in_B, sp_B = En_B(B) 200 | 201 | B2A = De_A(in_B + sp_A) 202 | A2B = De_B(in_A + sp_B) 203 | 204 | result = torch.cat((A[0], B[0], A2B[0], B2A[0]), 2) 205 | path = os.path.join(args.dataset + '_results', 'img', str(epoch+1) + '_epoch_' + args.dataset + '_' + str(n + 1) + '.png') 206 | plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2) 207 | n += 1 208 | 209 | torch.save(En_A.state_dict(), os.path.join(args.dataset + '_results', 'model', 'En_A_param_latest.pkl')) 210 | torch.save(En_B.state_dict(), os.path.join(args.dataset + '_results', 'model', 'En_B_param_latest.pkl')) 211 | torch.save(De_A.state_dict(), os.path.join(args.dataset + '_results', 'model', 'De_A_param_latest.pkl')) 212 | torch.save(De_B.state_dict(), os.path.join(args.dataset + '_results', 'model', 'De_B_param_latest.pkl')) 213 | torch.save(Disc_A.state_dict(), os.path.join(args.dataset + '_results', 'model', 'Disc_A_param_latest.pkl')) 214 | torch.save(Disc_B.state_dict(), os.path.join(args.dataset + '_results', 'model', 'Disc_B_param_latest.pkl')) 215 | 216 | 217 | if (epoch+1) % 50 == 0: 218 | torch.save(En_A.state_dict(), 219 | os.path.join(args.dataset + '_results', 'model', 'En_A_param_' + str(epoch+1) + '.pkl')) 220 | torch.save(En_B.state_dict(), 221 | os.path.join(args.dataset + '_results', 'model', 'En_B_param_' + str(epoch+1) + '.pkl')) 222 | torch.save(De_A.state_dict(), 223 | os.path.join(args.dataset + '_results', 'model', 'De_A_param_' + str(epoch+1) + '.pkl')) 224 | torch.save(De_B.state_dict(), 225 | os.path.join(args.dataset + '_results', 'model', 'De_B_param_' + str(epoch+1) + '.pkl')) 226 | torch.save(Disc_A.state_dict(), 227 | os.path.join(args.dataset + '_results', 'model', 'Disc_A_param_' + str(epoch+1) + '.pkl')) 228 | torch.save(Disc_B.state_dict(), 229 | os.path.join(args.dataset + '_results', 'model', 'Disc_B_param_' + str(epoch+1) + '.pkl')) 230 | 231 | total_time = time.time() - start_time 232 | train_hist['total_time'].append(total_time) 233 | 234 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_time'])), args.train_epoch, total_time)) 235 | print("Training finish!... save training results") 236 | with open(os.path.join(args.dataset + '_results', 'train_hist.pkl'), 'wb') as f: 237 | pickle.dump(train_hist, f) 238 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import datasets 4 | 5 | def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=False): 6 | dset = datasets.ImageFolder(path, transform) 7 | ind = dset.class_to_idx[subfolder] 8 | 9 | n = 0  # keep only the images that belong to the requested subfolder 10 | for i in range(len(dset)): 11 | if ind != dset.imgs[n][1]: 12 | del dset.imgs[n] 13 | n -= 1 14 | 15 | n += 1 16 | 17 | return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) 18 | 19 | def print_network(net): 20 | num_params = 0 21 | for param in net.parameters(): 22 | num_params += param.numel() 23 | print(net) 24 | print('Total number of parameters: %d' % num_params) 25 | 26 | def initialize_weights(net): 27 | for m in net.modules(): 28 | if isinstance(m, nn.Conv2d): 29 | m.weight.data.normal_(0, 0.02) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.ConvTranspose2d): 32 | m.weight.data.normal_(0, 0.02) 33 | m.bias.data.zero_() 34 | elif isinstance(m, nn.Linear): 35 | m.weight.data.normal_(0, 0.02) 36 | m.bias.data.zero_() 37 | elif isinstance(m, nn.BatchNorm2d): 38 | m.weight.data.fill_(1) 39 | m.bias.data.zero_() --------------------------------------------------------------------------------
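For quick qualitative checks, the saved state_dict files can be reloaded with the classes in networks.py. The following is a minimal single-image A→B sketch, assuming the default arguments (3-channel 64×64 images, --ngf 64) and the dataset_results/model/*_latest.pkl checkpoints written by train.py; input_A.png and reference_B.png are placeholder paths.
```python
# Minimal A -> B translation sketch, assuming the default training settings
# (--img_size 64, --ngf 64) and the checkpoint names written by train.py.
# The input/output image paths below are placeholders.
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

import networks

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the generator halves used for A -> B: En_A extracts the domain-independent
# feature of A, En_B provides the domain-B specific feature, De_B decodes into domain B.
En_A = networks.encoder(in_nc=3, nf=64, img_size=64).to(device)
En_B = networks.encoder(in_nc=3, nf=64, img_size=64).to(device)
De_B = networks.decoder(out_nc=3, nf=64).to(device)
En_A.load_state_dict(torch.load('dataset_results/model/En_A_param_latest.pkl', map_location=device))
En_B.load_state_dict(torch.load('dataset_results/model/En_B_param_latest.pkl', map_location=device))
De_B.load_state_dict(torch.load('dataset_results/model/De_B_param_latest.pkl', map_location=device))
En_A.eval()
En_B.eval()
De_B.eval()

# Same preprocessing as training.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

A = transform(Image.open('input_A.png').convert('RGB')).unsqueeze(0).to(device)       # image to translate
B = transform(Image.open('reference_B.png').convert('RGB')).unsqueeze(0).to(device)   # domain-B reference

with torch.no_grad():
    in_A, _ = En_A(A)           # domain-independent feature of A
    _, sp_B = En_B(B)           # domain-specific feature of B
    A2B = De_B(in_A + sp_B)     # conditional translation, as in train.py

# Map the tanh output from [-1, 1] back to [0, 1] and save.
out = (A2B[0].cpu().numpy().transpose(1, 2, 0) + 1) / 2
plt.imsave('A2B.png', out)
```
Swapping the roles (En_B / En_A / De_A) gives the B→A direction, mirroring B2A = De_A(in_B + sp_A) in train.py.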