├── LICENSE.md
├── README.md
├── generate.py
├── imgs
│   ├── compare_dcgan.png
│   └── w_combined.png
├── main.py
├── models
│   ├── __init__.py
│   ├── dcgan.py
│   └── mlp.py
└── requirements.txt

/LICENSE.md:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2017, Martin Arjovsky (NYU), Soumith Chintala (Facebook), Leon Bottou (Facebook)
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Wasserstein GAN
===============

Code accompanying the paper ["Wasserstein GAN"](https://arxiv.org/abs/1701.07875)

## A few notes

- The first run on the LSUN dataset can take a long time (up to an hour) to create the dataloader. After the first run a small cache file is created and the process should take a matter of seconds. The cache is a list of indices into the lmdb database (of LSUN).
- The only addition to the code (that we forgot, and will add, on the paper) are [lines 163-166 of main.py](https://github.com/martinarjovsky/WassersteinGAN/blob/master/main.py#L163-L166). These lines act only on the first 25 generator iterations or very sporadically (once every 500 generator iterations). In such a case, they set the number of critic iterations to 100 instead of the default 5. This helps the critic start at (or near) its optimum even in the first iterations. There shouldn't be a major difference in performance, but it can help, especially when visualizing learning curves (since otherwise you'd see the loss going up until the critic is properly trained). This is also why the first 25 iterations take significantly longer than the rest of training. A sketch of this schedule follows the list.
- If your learning curve suddenly takes a big drop, take a look at [this issue](https://github.com/martinarjovsky/WassersteinGAN/issues/2). It happens when the critic fails to stay close to optimum, and hence its loss stops being a good estimate of the Wasserstein distance. Known causes are high learning rates and momentum; anything that helps the critic get back on track is likely to help with the issue.
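
The critic-iteration schedule described above boils down to the following (a minimal sketch; the function name is ours, but the logic mirrors `main.py`):

```python
# Number of critic iterations per generator iteration.
# Train the critic heavily early on (and periodically thereafter) so that
# its loss is a meaningful estimate of the Wasserstein distance.
def critic_iters(gen_iterations, default_diters=5):
    if gen_iterations < 25 or gen_iterations % 500 == 0:
        return 100
    return default_diters
```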

## Prerequisites

- Computer with Linux or OSX
- [PyTorch](http://pytorch.org)
- For training, an NVIDIA GPU is strongly recommended for speed. CPU is supported, but training is very slow.

Two main empirical claims:

### Generator sample quality correlates with discriminator loss

![gensample](imgs/w_combined.png "sample quality correlates with discriminator loss")

### Improved model stability

![stability](imgs/compare_dcgan.png "stability")


## Reproducing LSUN experiments

**With DCGAN:**

```bash
python main.py --dataset lsun --dataroot [lsun-train-folder] --cuda
```

**With MLP:**

```bash
python main.py --mlp_G --ngf 512
```

Generated samples will be in the `samples` folder.

If you plot the value `-Loss_D`, you can reproduce the curves from the paper. As mentioned in the paper, the plotted curves have a median filter applied to them:

```python
med_filtered_loss = scipy.signal.medfilt(-Loss_D, 101)
```
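
Spelled out in full, that post-processing step might look like this (a sketch; it assumes you logged the critic losses into a Python list `loss_d`, one value per generator iteration):

```python
import numpy as np
import scipy.signal
import matplotlib.pyplot as plt

curve = -np.asarray(loss_d, dtype='float64')          # plot -Loss_D, as above
med_filtered_loss = scipy.signal.medfilt(curve, 101)  # kernel size 101, as above
plt.plot(med_filtered_loss)
plt.xlabel('generator iterations')
plt.ylabel('-Loss_D (Wasserstein estimate)')
plt.show()
```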

A more complete README is in the works.
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import os
import json

import models.dcgan as dcgan
import models.mlp as mlp

if __name__=="__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', required=True, type=str, help='path to generator config .json file')
    parser.add_argument('-w', '--weights', required=True, type=str, help='path to generator weights .pth file')
    parser.add_argument('-o', '--output_dir', required=True, type=str, help="path to the output directory")
    parser.add_argument('-n', '--nimages', required=True, type=int, help="number of images to generate")
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    opt = parser.parse_args()

    with open(opt.config, 'r') as gencfg:
        generator_config = json.loads(gencfg.read())

    imageSize = generator_config["imageSize"]
    nz = generator_config["nz"]
    nc = generator_config["nc"]
    ngf = generator_config["ngf"]
    noBN = generator_config["noBN"]
    ngpu = generator_config["ngpu"]
    mlp_G = generator_config["mlp_G"]
    n_extra_layers = generator_config["n_extra_layers"]

    if noBN:
        netG = dcgan.DCGAN_G_nobn(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
    elif mlp_G:
        netG = mlp.MLP_G(imageSize, nz, nc, ngf, ngpu)
    else:
        netG = dcgan.DCGAN_G(imageSize, nz, nc, ngf, ngpu, n_extra_layers)

    # load weights
    netG.load_state_dict(torch.load(opt.weights))

    # initialize noise
    fixed_noise = torch.FloatTensor(opt.nimages, nz, 1, 1).normal_(0, 1)

    if opt.cuda:
        netG.cuda()
        fixed_noise = fixed_noise.cuda()

    fake = netG(fixed_noise)
    fake.data = fake.data.mul(0.5).add(0.5)  # map tanh output from [-1, 1] to [0, 1]

    for i in range(opt.nimages):
        vutils.save_image(fake.data[i, ...].reshape((1, nc, imageSize, imageSize)), os.path.join(opt.output_dir, "generated_%02d.png" % i))
--------------------------------------------------------------------------------
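
A typical invocation of `generate.py`, after training has written a config and checkpoints into the samples folder, might look like this (the paths and epoch number are illustrative):

```bash
python generate.py -c samples/generator_config.json -w samples/netG_epoch_24.pth -o generated -n 25 --cuda
```

Note that the output directory must already exist; `generate.py` does not create it.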
/imgs/compare_dcgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinarjovsky/WassersteinGAN/f7a01e82007ea408647c451b9e1c8f1932a3db67/imgs/compare_dcgan.png
--------------------------------------------------------------------------------
/imgs/w_combined.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinarjovsky/WassersteinGAN/f7a01e82007ea408647c451b9e1c8f1932a3db67/imgs/w_combined.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import os
import json

import models.dcgan as dcgan
import models.mlp as mlp

if __name__=="__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw')
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
    parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
    parser.add_argument('--nc', type=int, default=3, help='input image channels')
    parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
    parser.add_argument('--ngf', type=int, default=64)
    parser.add_argument('--ndf', type=int, default=64)
    parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
    parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for Critic, default=0.00005')
    parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for Generator, default=0.00005')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
    parser.add_argument('--netG', default='', help="path to netG (to continue training)")
    parser.add_argument('--netD', default='', help="path to netD (to continue training)")
    parser.add_argument('--clamp_lower', type=float, default=-0.01)
    parser.add_argument('--clamp_upper', type=float, default=0.01)
    parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
    parser.add_argument('--noBN', action='store_true', help='use batchnorm or not (only for DCGAN)')
    parser.add_argument('--mlp_G', action='store_true', help='use MLP for G')
    parser.add_argument('--mlp_D', action='store_true', help='use MLP for D')
    parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
    parser.add_argument('--experiment', default=None, help='Where to store samples and models')
    parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
    opt = parser.parse_args()
    print(opt)

    if opt.experiment is None:
        opt.experiment = 'samples'
    os.system('mkdir {0}'.format(opt.experiment))

    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Scale(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Scale(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Scale(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ])
                               )
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))

    ngpu = int(opt.ngpu)
    nz = int(opt.nz)
    ngf = int(opt.ngf)
    ndf = int(opt.ndf)
    nc = int(opt.nc)
    n_extra_layers = int(opt.n_extra_layers)

    # write out generator config to generate images together with training checkpoints (.pth)
    generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
    with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
        gcfg.write(json.dumps(generator_config)+"\n")

    # custom weights initialization called on netG and netD
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    if opt.noBN:
        netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
    elif opt.mlp_G:
        netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
    else:
        netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)

    netG.apply(weights_init)
    if opt.netG != '':  # load checkpoint if needed
        netG.load_state_dict(torch.load(opt.netG))
    print(netG)

    if opt.mlp_D:
        netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
    else:
        netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers)
    netD.apply(weights_init)

    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD))
    print(netD)

    input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
    noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
    fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
    one = torch.FloatTensor([1])
    mone = one * -1

    if opt.cuda:
        netD.cuda()
        netG.cuda()
        input = input.cuda()
        one, mone = one.cuda(), mone.cuda()
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

    # setup optimizer
    if opt.adam:
        optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999))
        optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999))
    else:
        optimizerD = optim.RMSprop(netD.parameters(), lr=opt.lrD)
        optimizerG = optim.RMSprop(netG.parameters(), lr=opt.lrG)

    gen_iterations = 0
    for epoch in range(opt.niter):
        data_iter = iter(dataloader)
        i = 0
        while i < len(dataloader):
            ############################
            # (1) Update D network
            ###########################
            for p in netD.parameters():  # reset requires_grad
                p.requires_grad = True   # they are set to False below in netG update

            # train the discriminator Diters times
            if gen_iterations < 25 or gen_iterations % 500 == 0:
                Diters = 100
            else:
                Diters = opt.Diters
            j = 0
            while j < Diters and i < len(dataloader):
                j += 1

                # clamp parameters to a cube
                for p in netD.parameters():
                    p.data.clamp_(opt.clamp_lower, opt.clamp_upper)

                data = next(data_iter)
                i += 1

                # train with real
                real_cpu, _ = data
                netD.zero_grad()
                batch_size = real_cpu.size(0)

                if opt.cuda:
                    real_cpu = real_cpu.cuda()
                input.resize_as_(real_cpu).copy_(real_cpu)
                inputv = Variable(input)

                errD_real = netD(inputv)
                errD_real.backward(one)

                # train with fake
                noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
                noisev = Variable(noise, volatile=True)  # totally freeze netG
                fake = Variable(netG(noisev).data)
                inputv = fake
                errD_fake = netD(inputv)
                errD_fake.backward(mone)
                errD = errD_real - errD_fake
                optimizerD.step()

            ############################
            # (2) Update G network
            ###########################
            for p in netD.parameters():
                p.requires_grad = False  # to avoid computation
            netG.zero_grad()
            # in case our last batch was the tail batch of the dataloader,
            # make sure we feed a full batch of noise
            noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
            noisev = Variable(noise)
            fake = netG(noisev)
            errG = netD(fake)
            errG.backward(one)
            optimizerG.step()
            gen_iterations += 1

            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
                  % (epoch, opt.niter, i, len(dataloader), gen_iterations,
                     errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))
            if gen_iterations % 500 == 0:
                real_cpu = real_cpu.mul(0.5).add(0.5)
                vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment))
                fake = netG(Variable(fixed_noise, volatile=True))
                fake.data = fake.data.mul(0.5).add(0.5)
                vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))

        # do checkpointing
        torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
        torch.save(netD.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch))
--------------------------------------------------------------------------------
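
The `backward(one)` / `backward(mone)` calls above are an old-PyTorch idiom for descending `errD = errD_real - errD_fake` (whose negation is the Wasserstein estimate) and `errG` directly. A self-contained toy version of the same update rule, written with explicit scalar losses, might look like this (the network shapes, data distribution, and iteration counts are illustrative only; the learning rate, clamp values, and critic iterations match the defaults above):

```python
import torch
import torch.nn as nn
import torch.optim as optim

netD = nn.Linear(2, 1)   # stand-in critic
netG = nn.Linear(8, 2)   # stand-in generator
optimizerD = optim.RMSprop(netD.parameters(), lr=5e-5)
optimizerG = optim.RMSprop(netG.parameters(), lr=5e-5)

for gen_iteration in range(100):
    # (1) critic update, Diters times, with weights clamped to a cube
    for _ in range(5):
        for p in netD.parameters():
            p.data.clamp_(-0.01, 0.01)
        real = torch.randn(64, 2) + 3.0                # stand-in data batch
        fake = netG(torch.randn(64, 8)).detach()       # freeze G while training D
        errD = netD(real).mean() - netD(fake).mean()   # same quantity as Loss_D above
        optimizerD.zero_grad()
        errD.backward()
        optimizerD.step()
    # (2) generator update: descend the critic's score of the fakes
    errG = netD(netG(torch.randn(64, 8))).mean()
    optimizerG.zero_grad()
    errG.backward()
    optimizerG.step()
```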
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinarjovsky/WassersteinGAN/f7a01e82007ea408647c451b9e1c8f1932a3db67/models/__init__.py
--------------------------------------------------------------------------------
/models/dcgan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.parallel

class DCGAN_D(nn.Module):
    def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
        super(DCGAN_D, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        main = nn.Sequential()
        # input is nc x isize x isize
        main.add_module('initial:{0}-{1}:conv'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial:{0}:relu'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        csize, cndf = isize // 2, ndf

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}:{1}:conv'.format(t, cndf),
                            nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}:{1}:batchnorm'.format(t, cndf),
                            nn.BatchNorm2d(cndf))
            main.add_module('extra-layers-{0}:{1}:relu'.format(t, cndf),
                            nn.LeakyReLU(0.2, inplace=True))

        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid:{0}-{1}:conv'.format(in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid:{0}:batchnorm'.format(out_feat),
                            nn.BatchNorm2d(out_feat))
            main.add_module('pyramid:{0}:relu'.format(out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2

        # state size. K x 4 x 4
        main.add_module('final:{0}-{1}:conv'.format(cndf, 1),
                        nn.Conv2d(cndf, 1, 4, 1, 0, bias=False))
        self.main = main

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        output = output.mean(0)
        return output.view(1)

class DCGAN_G(nn.Module):
    def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
        super(DCGAN_G, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        cngf, tisize = ngf//2, 4
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2

        main = nn.Sequential()
        # input is Z, going into a convolution
        main.add_module('initial:{0}-{1}:convt'.format(nz, cngf),
                        nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial:{0}:batchnorm'.format(cngf),
                        nn.BatchNorm2d(cngf))
        main.add_module('initial:{0}:relu'.format(cngf),
                        nn.ReLU(True))

        csize, cndf = 4, cngf
        while csize < isize//2:
            main.add_module('pyramid:{0}-{1}:convt'.format(cngf, cngf//2),
                            nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False))
            main.add_module('pyramid:{0}:batchnorm'.format(cngf//2),
                            nn.BatchNorm2d(cngf//2))
            main.add_module('pyramid:{0}:relu'.format(cngf//2),
                            nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}:{1}:conv'.format(t, cngf),
                            nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}:{1}:batchnorm'.format(t, cngf),
                            nn.BatchNorm2d(cngf))
            main.add_module('extra-layers-{0}:{1}:relu'.format(t, cngf),
                            nn.ReLU(True))

        main.add_module('final:{0}-{1}:convt'.format(cngf, nc),
                        nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final:{0}:tanh'.format(nc),
                        nn.Tanh())
        self.main = main

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
###############################################################################
class DCGAN_D_nobn(nn.Module):
    def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
        super(DCGAN_D_nobn, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        main = nn.Sequential()
        # input is nc x isize x isize
        main.add_module('initial:{0}-{1}:conv'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial:{0}:relu'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        csize, cndf = isize // 2, ndf

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}:{1}:conv'.format(t, cndf),
                            nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}:{1}:relu'.format(t, cndf),
                            nn.LeakyReLU(0.2, inplace=True))

        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid:{0}-{1}:conv'.format(in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid:{0}:relu'.format(out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2

        # state size. K x 4 x 4
        main.add_module('final:{0}-{1}:conv'.format(cndf, 1),
                        nn.Conv2d(cndf, 1, 4, 1, 0, bias=False))
        self.main = main

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        output = output.mean(0)
        return output.view(1)

class DCGAN_G_nobn(nn.Module):
    def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
        super(DCGAN_G_nobn, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        cngf, tisize = ngf//2, 4
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2

        main = nn.Sequential()
        main.add_module('initial:{0}-{1}:convt'.format(nz, cngf),
                        nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial:{0}:relu'.format(cngf),
                        nn.ReLU(True))

        csize, cndf = 4, cngf
        while csize < isize//2:
            main.add_module('pyramid:{0}-{1}:convt'.format(cngf, cngf//2),
                            nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False))
            main.add_module('pyramid:{0}:relu'.format(cngf//2),
                            nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}:{1}:conv'.format(t, cngf),
                            nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}:{1}:relu'.format(t, cngf),
                            nn.ReLU(True))

        main.add_module('final:{0}-{1}:convt'.format(cngf, nc),
                        nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final:{0}:tanh'.format(nc),
                        nn.Tanh())
        self.main = main

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
--------------------------------------------------------------------------------
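
With a recent PyTorch (where plain tensors can be passed to these modules), a quick shape check of the models above might look like this, run from the repository root (the sizes are just examples):

```python
import torch
from models import dcgan

netG = dcgan.DCGAN_G(isize=64, nz=100, nc=3, ngf=64, ngpu=1)
netD = dcgan.DCGAN_D(isize=64, nz=100, nc=3, ndf=64, ngpu=1)

z = torch.randn(4, 100, 1, 1)  # a batch of 4 latent vectors
x = netG(z)                    # expected shape: (4, 3, 64, 64)
score = netD(x)                # a 1-element tensor: the mean critic score
print(x.size(), score.size())
```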
/models/mlp.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import torch.nn as nn

class MLP_G(nn.Module):
    def __init__(self, isize, nz, nc, ngf, ngpu):
        super(MLP_G, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # Z goes into a linear of size: ngf
            nn.Linear(nz, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, nc * isize * isize),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        input = input.view(input.size(0), input.size(1))
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(output.size(0), self.nc, self.isize, self.isize)


class MLP_D(nn.Module):
    def __init__(self, isize, nz, nc, ndf, ngpu):
        super(MLP_D, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # the flattened image goes into a linear of size: ndf
            nn.Linear(nc * isize * isize, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, 1),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        input = input.view(input.size(0),
                           input.size(1) * input.size(2) * input.size(3))
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        output = output.mean(0)
        return output.view(1)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
--------------------------------------------------------------------------------