├── README.md ├── utils.py ├── resnet.py ├── network.py └── main.py /README.md: -------------------------------------------------------------------------------- 1 | # VGrow in PyTorch 2 | PyTorch implementation of [Deep Generative Learning via Variational Gradient Flow](https://arxiv.org/abs/1901.08469). 3 | 4 | # Prerequisites 5 | Python 3.5+ 6 | PyTorch v0.4.1 7 | 8 | # Usage 9 | To run VGrow on [MNIST](http://yann.lecun.com/exdb/mnist/), [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist) 10 | and [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html), use the following `cmd` with default arguments 11 | `python main.py --divergence KL --dataset --dataroot ` 12 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataset import Dataset 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class poolSet(Dataset): 7 | 8 | def __init__(self, p_z, p_img): 9 | self.len = len(p_z) 10 | self.z_data = p_z 11 | self.img_data = p_img 12 | 13 | def __getitem__(self, index): 14 | return self.z_data[index], self.img_data[index] 15 | 16 | def __len__(self): 17 | return self.len 18 | 19 | def inceptionScore(net, netG, device, nz, nclass, batchSize=250, eps=1e-6): 20 | 21 | net.to(device) 22 | netG.to(device) 23 | net.eval() 24 | netG.eval() 25 | 26 | pyx = np.zeros((batchSize*200, nclass)) 27 | 28 | for i in range(200): 29 | 30 | eval_z_b = torch.randn(batchSize, nz).to(device) 31 | fake_img_b = netG(eval_z_b) 32 | pyx[i*batchSize: (i+1)*batchSize] = F.softmax(net(fake_img_b).detach(), dim=1).cpu().numpy() 33 | 34 | py = np.mean(pyx, axis=0) 35 | 36 | kl = np.sum(pyx * (np.log(pyx+eps) - np.log(py+eps)), axis=1) 37 | kl = kl.mean() 38 | 39 | return np.exp(kl) 40 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, nc, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(nc): 98 | return PreActResNet(PreActBlock, [2,2,2,2], nc) 99 | 100 | def PreActResNet34(nc): 101 | return PreActResNet(PreActBlock, [3,4,6,3], nc) 102 | 103 | def PreActResNet50(nc): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3], nc) 105 | 106 | def PreActResNet101(nc): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3], nc) 108 | 109 | def PreActResNet152(nc): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3], nc) 111 | 112 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.spectral_norm as sn 5 | 6 | #------------------------- ResNet ---------------------------- 7 | class ResBlock_G(nn.Module): 8 | def __init__(self, nf, up=False): 9 | super(ResBlock_G, self).__init__() 10 | 11 | self.nf = nf 12 | self.up = up 13 | 14 | self.SubBlock1 = nn.Sequential( 15 | nn.ReLU(True) 16 | ) 17 | 18 | self.SubBlock2 = nn.Sequential( 19 | sn(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), n_power_iterations=5), 20 | nn.BatchNorm2d(nf), 21 | nn.ReLU(True), 22 | sn(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), n_power_iterations=5) 23 | ) 24 | 25 | self.conv_shortcut = sn(nn.Conv2d(nf, nf, 1, 1, 0, bias=False), n_power_iterations=5) 26 | 27 | def forward(self, x): 28 | out = self.SubBlock1(x) 29 | 30 | if self.up: 31 | out = F.interpolate(out, scale_factor=2) 32 | shortcut = F.interpolate(x, scale_factor=2) 33 | shortcut = self.conv_shortcut(shortcut) 34 | 35 | else: 36 | shortcut = x 37 | 38 | out = self.SubBlock2(out) 39 | out += shortcut 40 | 41 | return out 42 | 43 | 44 | class G_resnet(nn.Module): 45 | def __init__(self, nc=3, ngf=128, nz=128): 46 | super(G_resnet, self).__init__() 47 | 48 | self.nc = nc 49 | self.ngf = ngf 50 | self.nz = nz 51 | 52 | self.linear = sn(nn.Linear(nz, 16*ngf), n_power_iterations=5) 53 | self.block1 = ResBlock_G(ngf, True) 54 | self.block2 = ResBlock_G(ngf, True) 55 | self.block3 = ResBlock_G(ngf, True) 56 | self.block4 = nn.Sequential( 57 | nn.ReLU(True), 58 | sn(nn.Conv2d(ngf, nc, 3, 1, 1, bias=False), n_power_iterations=5), 59 | nn.Tanh() 60 | ) 61 | 62 | 63 | def forward(self, x): 64 | out = self.linear(x) 65 | out = self.block1(out.view(-1, self.ngf, 4, 4)) 66 | out = self.block2(out) 67 | out = self.block3(out) 68 | out = self.block4(out) 69 | 70 | return out.view(-1, self.nc, 32, 32) 71 | 72 | 73 | class ResBlock_D(nn.Module): 74 | def __init__(self, nf, down=False, nc=3, first=False): 75 | super(ResBlock_D, self).__init__() 76 | 77 | self.nf = nf 78 | self.down = down 79 | self.nc = nc 80 | self.first = first 81 | nf_in = nc if first else nf 82 | 83 | self.relu1 = nn.ReLU(True) 84 | self.conv1 = sn(nn.Conv2d(nf_in, nf, 3, 1, 1, bias=False), n_power_iterations=5) 85 | self.relu2 = nn.ReLU(True) 86 | self.conv2 = sn(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), n_power_iterations=5) 87 | 88 | self.conv_shortcut = sn(nn.Conv2d(nf_in, nf, 1, 1, 0, bias=False), n_power_iterations=5) 89 | 90 | def forward(self, x): 91 | out = x if self.first else self.relu1(x) 92 | out = self.conv1(out) 93 | out = self.relu2(out) 94 | out = self.conv2(out) 95 | 96 | if self.down: 97 | out = F.avg_pool2d(out, kernel_size=2, stride=2) 98 | shortcut = self.conv_shortcut(x) 99 | shortcut = F.avg_pool2d(shortcut, kernel_size=2, stride=2) 100 | 101 | else: 102 | shortcut = x 103 | 104 | out += shortcut 105 | 106 | return out 107 | 108 | 109 | class D_resnet(nn.Module): 110 | def __init__(self, nc, ndf): 111 | super(D_resnet, self).__init__() 112 | self.nc = nc 113 | self.ndf = ndf 114 | 115 | self.block1 = ResBlock_D(ndf, True, nc, True) 116 | self.block2 = ResBlock_D(ndf, True) 117 | self.block3 = ResBlock_D(ndf) 118 | self.block4 = ResBlock_D(ndf) 119 | self.relu = nn.ReLU(True) 120 | self.linear = sn(nn.Linear(ndf, 1), n_power_iterations=5) 121 | 122 | def forward(self, x): 123 | out = self.block1(x) 124 | out = self.block2(out) 125 | out = self.block3(out) 126 | out = self.block4(out) 127 | out = self.relu(out) 128 | out = out.sum(-1).sum(-1) 129 | out = self.linear(out.view(-1, self.ndf)) 130 | return out.view(-1, 1).squeeze() 131 | 132 | 133 | #-------------------- init ---------------------- 134 | def weights_init(m): 135 | classname = m.__class__.__name__ 136 | if classname.find('Conv') != -1: 137 | m.weight.data.normal_(0.0, 0.02) 138 | 139 | elif classname.find('BatchNorm2d') != -1: 140 | m.weight.data.normal_(1.0, 0.02) 141 | m.bias.data.fill_(0.0) 142 | 143 | elif classname.find('Linear') != -1: 144 | m.weight.data.normal_(0.0, 0.02) 145 | m.bias.data.fill_(0.0) 146 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | # basic functions 4 | import argparse 5 | import os 6 | import random 7 | import numpy as np 8 | import pandas as pd 9 | import matplotlib.pyplot as plt 10 | plt.switch_backend('agg') 11 | 12 | # torch functions 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.optim as optim 18 | from torch.optim.lr_scheduler import MultiStepLR 19 | import torch.utils.data 20 | import torchvision.datasets as dset 21 | import torchvision.transforms as transforms 22 | import torchvision.utils as vutils 23 | 24 | # local functions 25 | from network import * 26 | from resnet import * 27 | from utils import poolSet, inceptionScore 28 | 29 | #-------------------------------------------------------------------- 30 | # input arguments 31 | parser = argparse.ArgumentParser(description='VGrow') 32 | parser.add_argument('--divergence', '-div', type=str, default='KL', help='KL | logd | JS | Jeffrey') 33 | parser.add_argument('--dataset', required=True, help='mnist | fashionmnist | cifar10') 34 | parser.add_argument('--dataroot', required=True, help='path to dataset') 35 | 36 | parser.add_argument('--gpuDevice', type=str, default='1', help='CUDA_VISIBLE_DEVICES') 37 | parser.add_argument('--workers', type=int, default=0, help='number of data loading workers') 38 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 39 | parser.add_argument('--imageSize', type=int, default=32, help='input image size') 40 | 41 | parser.add_argument('--nz', type=int, default=128, help='size of the latent vector') 42 | parser.add_argument('--ngf', type=int, default=128) 43 | parser.add_argument('--ndf', type=int, default=128) 44 | 45 | parser.add_argument('--nEpoch', type=int, default=10000, help='maximum Outer Loops') 46 | parser.add_argument('--nDiter', type=int, default=1, help='number of D update') 47 | parser.add_argument('--nPiter', type=int, default=20, help='number of particle update') 48 | parser.add_argument('--nProj', type=int, default=20, help='number of G projection') 49 | parser.add_argument('--nPool', type=int, default=20, help='times of batch size for particle pool') 50 | parser.add_argument('--period', type=int, default=50, help='period of saving ckpts') 51 | 52 | parser.add_argument('--eta', type=float, default=0.5, help='learning rate for particle update') 53 | parser.add_argument('--lrg', type=float, default=0.0001, help='learning rate for G, default=0.0001') 54 | parser.add_argument('--lrd', type=float, default=0.0001, help='learning rate for D, default=0.0001') 55 | parser.add_argument('--decay_g', type=bool, default=True, help='lr_g decay') 56 | parser.add_argument('--decay_d', type=bool, default=True, help='lr_d decay') 57 | 58 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 59 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') 60 | parser.add_argument('--netG', default='', help='path to netG (to continue training)') 61 | parser.add_argument('--netD', default='', help='path to netD (to continue training)') 62 | parser.add_argument('--outf', default='./results', help='folder to output images and model checkpoints') 63 | parser.add_argument('--resume', type=bool, default=False, help='resume from checkpoint') 64 | parser.add_argument('--resume_epoch', type=int, default=0) 65 | parser.add_argument('--start_save', type=int, default=800) 66 | parser.add_argument('--manualSeed', type=int, help='manual seed') 67 | parser.add_argument('--increase_nProj', type=bool, default=True, help='increase the projection times') 68 | 69 | opt = parser.parse_args() 70 | print(opt) 71 | 72 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpuDevice 73 | 74 | try: 75 | os.makedirs(opt.outf) 76 | except OSError: 77 | pass 78 | 79 | try: 80 | os.mkdir('./projection_loss') 81 | except: 82 | pass 83 | 84 | if opt.manualSeed is None: 85 | opt.manualSeed = random.randint(1, 10000) 86 | print('Random Seed: ', opt.manualSeed) 87 | random.seed(opt.manualSeed) 88 | torch.manual_seed(opt.manualSeed) 89 | 90 | cudnn.benchmark = True 91 | 92 | train_transforms = transforms.Compose([ 93 | transforms.Resize(opt.imageSize), 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 96 | ]) 97 | if opt.dataset == 'mnist': 98 | dataset = dset.MNIST(root=opt.dataroot, download=True, 99 | transform=train_transforms) 100 | nc = 1 101 | nclass = 10 102 | 103 | elif opt.dataset == 'fashionmnist': 104 | dataset = dset.FashionMNIST(root=opt.dataroot, download=True, 105 | transform=train_transforms) 106 | nc = 1 107 | nclass = 10 108 | 109 | elif opt.dataset == 'cifar10': 110 | dataset = dset.CIFAR10(root=opt.dataroot, download=True, 111 | transform=train_transforms) 112 | nc = 3 113 | nclass = 10 114 | 115 | else: 116 | raise NameError 117 | 118 | assert dataset 119 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 120 | shuffle=True, num_workers=int(opt.workers)) 121 | 122 | device = torch.device('cuda:0' if torch.cuda.is_available() and not opt.cuda else 'cpu') 123 | ngpu = int(opt.ngpu) 124 | nz = int(opt.nz) 125 | ngf = int(opt.ngf) 126 | ndf = int(opt.ndf) 127 | eta = float(opt.eta) 128 | 129 | # nets 130 | netG = G_resnet(nc, ngf, nz) 131 | netD = D_resnet(nc, ndf) 132 | 133 | netG.apply(weights_init) 134 | netG.to(device) 135 | netD.apply(weights_init) 136 | netD.to(device) 137 | print('#-----------GAN initializd-----------#') 138 | 139 | if opt.resume: 140 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 141 | state = torch.load('./checkpoint/GBGAN-%s-%s-%s-ckpt.t7' % (opt.divergence, opt.dataset, str(opt.resume_epoch))) 142 | netG.load_state_dict(state['netG']) 143 | netD.load_state_dict(state['netD']) 144 | start_epoch = state['epoch'] + 1 145 | is_score = state['is_score'] 146 | best_is = state['best_is'] 147 | loss_G = state['loss_G'] 148 | print('#-----------Resumed from checkpoint-----------#') 149 | 150 | else: 151 | start_epoch = 0 152 | is_score = [] 153 | best_is = 0.0 154 | 155 | netIncept = PreActResNet18(nc) 156 | netIncept.to(device) 157 | netIncept = torch.nn.DataParallel(netIncept) 158 | 159 | if torch.cuda.is_available() and not opt.cuda: 160 | checkpoint = torch.load('./checkpoint/resnet18-%s-ckpt.t7' % opt.dataset) 161 | netIncept.load_state_dict(checkpoint['net']) 162 | 163 | else: 164 | checkpoint = torch.load('./checkpoint/resnet18-%s-ckpt.t7' % opt.dataset, map_location=lambda storage, loc: storage) 165 | netIncept.load_state_dict(checkpoint['net']) 166 | 167 | print('#------------Classifier load finished------------#') 168 | 169 | 170 | poolSize = opt.batchSize * opt.nPool 171 | 172 | z_b = torch.FloatTensor(opt.batchSize, nz).to(device) 173 | img_b = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize).to(device) 174 | p_z = torch.FloatTensor(poolSize, nz).to(device) 175 | p_img = torch.FloatTensor(poolSize, nc, opt.imageSize, opt.imageSize).to(device) 176 | 177 | show_z_b = torch.FloatTensor(64, nz).to(device) 178 | eval_z_b = torch.FloatTensor(250, nz).to(device) 179 | 180 | # set optimizer 181 | optim_D = optim.RMSprop(netD.parameters(), lr=opt.lrd) 182 | optim_G = optim.RMSprop(netG.parameters(), lr=opt.lrg) 183 | 184 | if opt.dataset == 'mnist': 185 | scheduler_D = MultiStepLR(optim_D, milestones=[400, 800], gamma=0.5) 186 | scheduler_G = MultiStepLR(optim_G, milestones=[400, 800], gamma=0.5) 187 | 188 | elif opt.dataset == 'fashionmnist': 189 | scheduler_D = MultiStepLR(optim_D, milestones=[400, 800, 1200], gamma=0.5) 190 | scheduler_G = MultiStepLR(optim_G, milestones=[400, 800, 1200], gamma=0.5) 191 | 192 | elif opt.dataset == 'cifar10': 193 | scheduler_D = MultiStepLR(optim_D, milestones=[800, 1600, 2400], gamma=0.5) 194 | scheduler_G = MultiStepLR(optim_G, milestones=[800, 1600, 2400], gamma=0.5) 195 | 196 | # set criterion 197 | criterion_G = nn.MSELoss() 198 | 199 | def get_nProj_mfm(epoch): 200 | if epoch < 200: 201 | nProj_t = 5 202 | elif 199 < epoch < 1000: 203 | nProj_t = 10 204 | elif 999 < epoch < 1500: 205 | nProj_t = 15 206 | elif 1499 < epoch: 207 | nProj_t = 20 208 | 209 | return nProj_t 210 | 211 | def get_nProj_cf(epoch): 212 | if epoch < 600: 213 | nProj_t = 5 214 | elif 599 < epoch < 2000: 215 | nProj_t = 10 216 | elif 1999 < epoch < 3000: 217 | nProj_t = 15 218 | elif 2999 < epoch: 219 | nProj_t = 20 220 | 221 | return nProj_t 222 | 223 | def get_nProj_t(epoch): 224 | if opt.dataset == 'mnist' or 'fashionmnist': 225 | nProj_t = get_nProj_mfm(epoch) 226 | elif opt.dataset == 'cifar10': 227 | nProj_t = get_nProj_cf(epoch) 228 | else: 229 | raise NameError 230 | 231 | return nProj_t 232 | 233 | #--------------------------- main function ---------------------------# 234 | real_show, _ = next(iter(dataloader)) 235 | vutils.save_image(real_show / 2 + 0.5, './results/real-%s.png' % opt.dataset, padding=0) 236 | 237 | for epoch in range(start_epoch, start_epoch + opt.nEpoch): 238 | # decay lr 239 | if opt.decay_d: 240 | scheduler_D.step() 241 | if opt.decay_g: 242 | scheduler_G.step() 243 | 244 | # input_pool 245 | netD.train() 246 | netG.eval() 247 | p_z.normal_() 248 | p_img.copy_(netG(p_z).detach()) 249 | 250 | for t in range(opt.nPiter): 251 | 252 | for _ in range(opt.nDiter): 253 | 254 | # Update D 255 | netD.zero_grad() 256 | # real 257 | real_img, _ = next(iter(dataloader)) 258 | img_b.copy_(real_img.to(device)) 259 | real_D_err = torch.log(1 + torch.exp(-netD(img_b))).mean() 260 | real_D_err.backward() 261 | 262 | # fake 263 | z_b_idx = random.sample(range(poolSize), opt.batchSize) 264 | img_b.copy_(p_img[z_b_idx]) 265 | fake_D_err = torch.log(1 + torch.exp(netD(img_b))).mean() 266 | fake_D_err.backward() 267 | 268 | optim_D.step() 269 | 270 | # update particle pool 271 | p_img_t = p_img.clone().to(device) 272 | 273 | p_img_t.requires_grad_(True) 274 | if p_img_t.grad is not None: 275 | p_img_t.grad.zero_() 276 | fake_D_score = netD(p_img_t) 277 | 278 | # set s(x) 279 | if opt.divergence == 'KL': 280 | s = torch.ones_like(fake_D_score.detach()) 281 | 282 | elif opt.divergence == 'logd': 283 | s = 1 / (1 + fake_D_score.detach().exp()) 284 | 285 | elif opt.divergence == 'JS': 286 | s = 1 / (1 + 1 / fake_D_score.detach().exp()) 287 | 288 | elif opt.divergence == 'Jeffrey': 289 | s = 1 + fake_D_score.detach().exp() 290 | 291 | else: 292 | raise NameError 293 | 294 | s.unsqueeze_(1).unsqueeze_(2).unsqueeze_(3).expand_as(p_img_t) 295 | fake_D_score.backward(torch.ones(len(p_img_t)).to(device)) 296 | p_img = torch.clamp(p_img + eta * s * p_img_t.grad, -1, 1) 297 | 298 | # update G 299 | netG.train() 300 | netD.eval() 301 | poolset = poolSet(p_z.cpu(), p_img.cpu()) 302 | poolloader = torch.utils.data.DataLoader(poolset, batch_size=opt.batchSize, shuffle=True, num_workers=opt.workers) 303 | 304 | loss_G = [] 305 | 306 | # set nProj_t 307 | if opt.increase_nProj: 308 | nProj_t = get_nProj_t(epoch) 309 | else: 310 | nProj_t = opt.nProj 311 | 312 | for _ in range(nProj_t): 313 | 314 | loss_G_t = [] 315 | for _, data_ in enumerate(poolloader, 0): 316 | netG.zero_grad() 317 | 318 | input_, target_ = data_ 319 | pred_ = netG(input_.to(device)) 320 | loss = criterion_G(pred_, target_.to(device)) 321 | loss.backward() 322 | 323 | optim_G.step() 324 | loss_G_t.append(loss.detach().cpu().item()) 325 | 326 | loss_G.append(np.mean(loss_G_t)) 327 | 328 | vutils.save_image(target_ / 2 + 0.5, './results/particle-%s-%s-%s-%s.png' 329 | % (str(epoch).zfill(4), opt.divergence, opt.dataset, str(opt.eta)), padding=0) 330 | print('Epoch(%s/%s)%d: %.4fe-4 | %.4fe-4 | %.4f' 331 | % (opt.divergence, opt.dataset, epoch, real_D_err*10000,fake_D_err*10000, p_img_t.grad.norm(p=2))) 332 | 333 | #----------------------------------------------------------------- 334 | if epoch % opt.period == 0: 335 | fig = plt.figure() 336 | plt.style.use('ggplot') 337 | plt.plot(loss_G, label=opt.divergence) 338 | plt.xlabel('Loop') 339 | plt.ylabel('Projection Loss') 340 | plt.legend() 341 | fig.savefig('./projection_loss/projection' + str(epoch).zfill(4) + '.png') 342 | plt.close() 343 | 344 | # show image 345 | netG.eval() 346 | show_z_b.normal_() 347 | fake_img = netG(show_z_b) 348 | vutils.save_image(fake_img.detach().cpu() / 2 + 0.5, './results/fake-%s-%s-%s-%s.png' 349 | % (str(epoch).zfill(4), opt.divergence, opt.dataset, str(opt.eta)), padding=0) 350 | 351 | # inception score 352 | is_score.append(inceptionScore(netIncept, netG, device, nz, nclass)) 353 | print('[%d] Inception Score is: %.4f' % (epoch, is_score[-1])) 354 | best_is = max(is_score[-1], best_is) 355 | 356 | fig = plt.figure() 357 | plt.style.use('ggplot') 358 | plt.plot(opt.period * (np.arange(epoch//opt.period + 1)), is_score, label=opt.divergence) 359 | plt.xlabel('Loop') 360 | plt.ylabel('Inception Score') 361 | plt.legend() 362 | fig.savefig('IS-%s-%s.png' % (opt.divergence, opt.dataset)) 363 | plt.close() 364 | 365 | if best_is == is_score[-1]: 366 | print('Save the best Inception Score: %.4f' % is_score[-1]) 367 | else: 368 | pass 369 | 370 | if epoch > opt.start_save and epoch % 50 == 0: 371 | state = { 372 | 'netG': netG.state_dict(), 373 | 'netD': netD.state_dict(), 374 | 'is_score': is_score, 375 | 'loss_G': loss_G, 376 | 'epoch': epoch, 377 | 'best_is': best_is 378 | } 379 | torch.save(state, './checkpoint/GBGAN-%s-%s-%s-ckpt.t7' % (opt.divergence, opt.dataset, str(epoch))) 380 | 381 | # save IS 382 | if epoch % 500 == 0: 383 | dataframe = pd.DataFrame({'IS-%s' % opt.divergence: is_score}) 384 | dataframe.to_csv('is-%s-%s.csv' % (opt.divergence, opt.dataset), sep=',') 385 | 386 | 387 | 388 | --------------------------------------------------------------------------------