├── LICENSE ├── README.md └── main.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dcgan_vae_pytorch 2 | dcgan combined with vae in pytorch! 
3 | 4 | this code is based on [pytorch/examples](https://github.com/pytorch/examples) and [staturecrane/dcgan_vae_torch](https://github.com/staturecrane/dcgan_vae_torch) 5 | 6 | The original article can be found [here](https://arxiv.org/abs/1512.09300) 7 | ## Requirements 8 | * torch 9 | * torchvision 10 | * visdom 11 | * (optional) lmdb 12 | 13 | ## Usage 14 | to start visdom: 15 | ``` 16 | python -m visdom.server 17 | ``` 18 | 19 | 20 | to start the training: 21 | ``` 22 | usage: main.py [-h] --dataset DATASET --dataroot DATAROOT [--workers WORKERS] 23 | [--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ] 24 | [--ngf NGF] [--ndf NDF] [--niter NITER] [--saveInt SAVEINT] [--lr LR] 25 | [--beta1 BETA1] [--cuda] [--ngpu NGPU] [--netG NETG] 26 | [--netD NETD] 27 | 28 | optional arguments: 29 | -h, --help show this help message and exit 30 | --dataset DATASET cifar10 | lsun | imagenet | folder | lfw 31 | --dataroot DATAROOT path to dataset 32 | --workers WORKERS number of data loading workers 33 | --batchSize BATCHSIZE 34 | input batch size 35 | --imageSize IMAGESIZE 36 | the height / width of the input image to network 37 | --nz NZ size of the latent z vector 38 | --ngf NGF 39 | --ndf NDF 40 | --niter NITER number of epochs to train for 41 | --saveInt SAVEINT number of epochs between checkpoints 42 | --lr LR learning rate, default=0.0002 43 | --beta1 BETA1 beta1 for adam. 
default=0.5 44 | --cuda enables cuda 45 | --ngpu NGPU number of GPUs to use 46 | --netG NETG path to netG (to continue training) 47 | --netD NETD path to netD (to continue training) 48 | ``` 49 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.legacy.nn as lnn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import torchvision.datasets as dset 14 | import torchvision.transforms as transforms 15 | import torchvision.utils as vutils 16 | import visdom 17 | from torch.autograd import Variable 18 | 19 | vis = visdom.Visdom() 20 | vis.env = 'vae_dcgan' 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ') 25 | parser.add_argument('--dataroot', required=True, help='path to dataset') 26 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 27 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 28 | parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network') 29 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') 30 | parser.add_argument('--ngf', type=int, default=64) 31 | parser.add_argument('--ndf', type=int, default=64) 32 | parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for') 33 | parser.add_argument('--saveInt', type=int, default=25, help='number of epochs between checkpoints') 34 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') 35 | 
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')

opt = parser.parse_args()
print(opt)

# Make sure the output folder exists. An already-existing folder is fine; any
# other failure is raised here instead of surfacing at the first checkpoint.
os.makedirs(opt.outf, exist_ok=True)

# Seed every RNG in use from one value so a run can be repeated by passing
# the printed seed back through --manualSeed.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")


def _build_transform(center_crop):
    """Return the preprocessing pipeline shared by all supported datasets.

    Images are scaled to ``opt.imageSize``, optionally center-cropped to a
    square of that size, converted to tensors and normalized per channel
    with mean 0.5 and std 0.5.
    """
    steps = [transforms.Scale(opt.imageSize)]
    if center_crop:
        steps.append(transforms.CenterCrop(opt.imageSize))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)


if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=_build_transform(center_crop=True))
elif opt.dataset == 'lsun':
    dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                        transform=_build_transform(center_crop=True))
elif opt.dataset == 'cifar10':
    # CIFAR-10 is the only branch that applies no center crop.
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=_build_transform(center_crop=False))
else:
    # Without this branch `dataset` would be unbound and the script would
    # die with a NameError that does not mention the offending option.
    raise ValueError("unknown dataset %r (expected cifar10 | lsun | imagenet | folder | lfw)"
                     % opt.dataset)
if len(dataset) == 0:
    raise ValueError('dataset at %r is empty' % opt.dataroot)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3


def _pyramid_depth(imageSize):
    """Return log2(imageSize) as an int.

    Raises:
        ValueError: if imageSize is not a power of 2 or is smaller than 8.
    """
    n = math.log2(imageSize)
    if n != round(n):
        raise ValueError('imageSize must be a power of 2')
    if n < 3:
        raise ValueError('imageSize must be at least 8')
    return int(n)


# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialize conv weights ~ N(0, 0.02) and batch-norm weights ~ N(1, 0.02).

    Matching is done on the class name, so every class whose name contains
    'Conv' (Conv2d and ConvTranspose2d here) is treated as a convolution.
    Batch-norm biases are reset to zero.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class _Sampler(nn.Module):
    """Reparameterization step: draw z = mu + eps * std with eps ~ N(0, 1)."""

    def __init__(self):
        super(_Sampler, self).__init__()

    def forward(self, input):
        """Sample from the Gaussian described by ``input = [mu, logvar]``."""
        mu = input[0]
        logvar = input[1]

        std = logvar.mul(0.5).exp_()  # calculate the STDEV
        # Allocate the noise from `std` itself so it always has the same
        # tensor type and device as the statistics it is combined with,
        # instead of depending on the global --cuda flag.
        eps = Variable(std.data.new(std.size()).normal_())  # random normalized noise
        return eps.mul(std).add_(mu)


class _Encoder(nn.Module):
    """Convolutional encoder producing the two latent statistics [mu, logvar].

    NOTE: module names are kept exactly as they were because they are part
    of the state_dict keys of saved checkpoints.
    NOTE(review): newer PyTorch releases appear to reject '.' inside module
    names - confirm before upgrading.
    """

    def __init__(self, imageSize):
        super(_Encoder, self).__init__()

        n = _pyramid_depth(imageSize)
        top_channels = ngf * 2 ** (n - 3)

        # Two parallel heads on top of the shared feature pyramid; forward()
        # returns them in the order [conv1, conv2], which _Sampler reads as
        # [mu, logvar].
        self.conv1 = nn.Conv2d(top_channels, nz, 4)
        self.conv2 = nn.Conv2d(top_channels, nz, 4)

        self.encoder = nn.Sequential()
        # input is (nc) x imageSize x imageSize
        self.encoder.add_module('input-conv', nn.Conv2d(nc, ngf, 4, 2, 1, bias=False))
        self.encoder.add_module('input-relu', nn.LeakyReLU(0.2, inplace=True))
        # Each stage halves the spatial size and doubles the channel count.
        for i in range(n - 3):
            c_in = ngf * 2 ** i
            c_out = ngf * 2 ** (i + 1)
            self.encoder.add_module('pyramid.{0}-{1}.conv'.format(c_in, c_out),
                                    nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            self.encoder.add_module('pyramid.{0}.batchnorm'.format(c_out),
                                    nn.BatchNorm2d(c_out))
            self.encoder.add_module('pyramid.{0}.relu'.format(c_out),
                                    nn.LeakyReLU(0.2, inplace=True))

        # state size. (ngf * 2**(n-3)) x 4 x 4

    def forward(self, input):
        """Return ``[conv1(features), conv2(features)]`` for an image batch."""
        output = self.encoder(input)
        return [self.conv1(output), self.conv2(output)]


class _netG(nn.Module):
    """VAE-style generator: encoder -> sampler -> transposed-conv decoder.

    The decoder alone (``netG.decoder``) maps an (nz x 1 x 1) latent to an
    (nc x imageSize x imageSize) image in [-1, 1] (Tanh output).
    NOTE: module names (including the 'ouput-conv' spelling) are kept
    verbatim because they are part of the state_dict keys of checkpoints.
    """

    def __init__(self, imageSize, ngpu):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        self.encoder = _Encoder(imageSize)
        self.sampler = _Sampler()

        n = _pyramid_depth(imageSize)
        top_channels = ngf * 2 ** (n - 3)

        self.decoder = nn.Sequential()
        # input is Z, going into a convolution
        self.decoder.add_module('input-conv',
                                nn.ConvTranspose2d(nz, top_channels, 4, 1, 0, bias=False))
        self.decoder.add_module('input-batchnorm', nn.BatchNorm2d(top_channels))
        self.decoder.add_module('input-relu', nn.LeakyReLU(0.2, inplace=True))

        # state size. (ngf * 2**(n-3)) x 4 x 4

        # Each stage doubles the spatial size and halves the channel count.
        for i in range(n - 3, 0, -1):
            c_in = ngf * 2 ** i
            c_out = ngf * 2 ** (i - 1)
            self.decoder.add_module('pyramid.{0}-{1}.conv'.format(c_in, c_out),
                                    nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False))
            self.decoder.add_module('pyramid.{0}.batchnorm'.format(c_out),
                                    nn.BatchNorm2d(c_out))
            self.decoder.add_module('pyramid.{0}.relu'.format(c_out),
                                    nn.LeakyReLU(0.2, inplace=True))

        self.decoder.add_module('ouput-conv', nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        self.decoder.add_module('output-tanh', nn.Tanh())

    def forward(self, input):
        """Encode, sample and decode ``input``; returns the reconstruction."""
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.encoder, input, range(self.ngpu))
            output = nn.parallel.data_parallel(self.sampler, output, range(self.ngpu))
            output = nn.parallel.data_parallel(self.decoder, output, range(self.ngpu))
        else:
            output = self.encoder(input)
            output = self.sampler(output)
            output = self.decoder(output)
        return output

    def make_cuda(self):
        """Move the three sub-networks to the GPU."""
        self.encoder.cuda()
        self.sampler.cuda()
        self.decoder.cuda()


netG = _netG(opt.imageSize, ngpu)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)


class _netD(nn.Module):
    """Convolutional discriminator returning one probability per image.

    NOTE: the pyramid module *names* are built from ``ngf`` while the layer
    widths use ``ndf``. The names are left untouched because they are part
    of the state_dict keys of saved checkpoints; with the default
    ngf == ndf the two coincide.
    """

    def __init__(self, imageSize, ngpu):
        super(_netD, self).__init__()
        self.ngpu = ngpu
        n = _pyramid_depth(imageSize)
        self.main = nn.Sequential()

        # input is (nc) x imageSize x imageSize
        self.main.add_module('input-conv', nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        self.main.add_module('relu', nn.LeakyReLU(0.2, inplace=True))

        # Each stage halves the spatial size and doubles the channel count.
        for i in range(n - 3):
            c_in = ndf * 2 ** i
            c_out = ndf * 2 ** (i + 1)
            self.main.add_module('pyramid.{0}-{1}.conv'.format(ngf * 2 ** i, ngf * 2 ** (i + 1)),
                                 nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            self.main.add_module('pyramid.{0}.batchnorm'.format(ngf * 2 ** (i + 1)),
                                 nn.BatchNorm2d(c_out))
            self.main.add_module('pyramid.{0}.relu'.format(ngf * 2 ** (i + 1)),
                                 nn.LeakyReLU(0.2, inplace=True))

        self.main.add_module('output-conv', nn.Conv2d(ndf * 2 ** (n - 3), 1, 4, 1, 0, bias=False))
        self.main.add_module('output-sigmoid', nn.Sigmoid())

    def forward(self, input):
        """Return the discriminator output reshaped to (-1, 1)."""
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1)


netD = _netD(opt.imageSize, ngpu)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCELoss()
MSECriterion = nn.MSELoss()

# Buffers allocated once here; the training loop resizes and refills them.
input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0

if opt.cuda:
    netD.cuda()
    netG.make_cuda()
    criterion.cuda()
    MSECriterion.cuda()
    input, label = input.cuda(), label.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

input = Variable(input)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# visdom window handle for the generated-sample preview
gen_win = None
# visdom window handle for the reconstruction preview
rec_win = None

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu, _ = data
        batch_size = real_cpu.size(0)
        # Reuse the preallocated buffers; the last batch may be smaller.
        input.data.resize_(real_cpu.size()).copy_(real_cpu)
        label.data.resize_(batch_size).fill_(real_label)

        output = netD(input)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.data.mean()

        # train with fake
        noise.data.resize_(batch_size, nz, 1, 1)
        noise.data.normal_(0, 1)
        gen = netG.decoder(noise)
        # Images are in [-1, 1] (Tanh); map to [0, 1] for display.
        gen_win = vis.image(gen.data[0].cpu() * 0.5 + 0.5, win=gen_win)
        label.data.fill_(fake_label)
        # detach() keeps this backward pass from reaching the decoder.
        output = netD(gen.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: VAE
        ###########################
        netG.zero_grad()

        encoded = netG.encoder(input)
        mu = encoded[0]
        logvar = encoded[1]

        # KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
        KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        KLD = torch.sum(KLD_element).mul_(-0.5)

        sampled = netG.sampler(encoded)
        rec = netG.decoder(sampled)
        rec_win = vis.image(rec.data[0].cpu() * 0.5 + 0.5, win=rec_win)

        MSEerr = MSECriterion(rec, input)

        VAEerr = KLD + MSEerr
        VAEerr.backward()
        optimizerG.step()

        ############################
        # (3) Update G network: maximize log(D(G(z)))
        ###########################
        # Clear the VAE gradients that were just applied; otherwise
        # errG.backward() would accumulate on top of them and the second
        # optimizerG.step() would apply the stale VAE gradients again.
        netG.zero_grad()

        label.data.fill_(real_label)  # fake labels are real for generator cost

        # Fresh forward pass: the graph of the previous `rec` was consumed
        # by VAEerr.backward().
        rec = netG(input)
        output = netD(rec)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_VAE: %.4f Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 VAEerr.data[0], errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))

    # Checkpoint every `saveInt` *completed* epochs. Testing the 0-based
    # `epoch` index meant that with the defaults (niter=25, saveInt=25) no
    # checkpoint was ever written. A non-positive saveInt disables saving
    # instead of raising ZeroDivisionError.
    epochs_done = epoch + 1
    if opt.saveInt > 0 and epochs_done % opt.saveInt == 0:
        torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epochs_done))
        torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epochs_done))