├── .gitignore ├── LICENSE.txt ├── README.md ├── loss_functions.py ├── main.py ├── models.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Added by ~ cadene ~ 2 | .DS_Store 3 | ._.DS_Store 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | *.pth.tar 106 | .idea 107 | results 108 | dataset 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Evgenii Zheltonozhskii 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import torch
import torch.autograd as autograd


def gradient_penalty(fake_data, real_data, discriminator):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Samples points uniformly on straight lines between paired real and fake
    samples and penalises the squared deviation of the discriminator's
    gradient L2-norm from 1 at those points.

    Args:
        fake_data: generator output, shape (N, C, H, W).
        real_data: real batch, same shape as ``fake_data``.
        discriminator: callable returning ``(score, pre_last_features)``.

    Returns:
        Scalar tensor: mean gradient penalty over the batch.
    """
    # One interpolation coefficient per sample, broadcast over C/H/W.
    # Created on the input's device (the original used torch.cuda.FloatTensor,
    # which broke CPU runs).
    alpha = torch.rand(fake_data.shape[0], 1, 1, 1, device=fake_data.device)
    alpha = alpha.expand(fake_data.shape)
    interpolates = (alpha * fake_data + (1 - alpha) * real_data).detach()
    interpolates.requires_grad_(True)
    disc_interpolates, _ = discriminator(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    # Penalise deviation of the per-sample gradient norm from 1.
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty


def consistency_term(real_data, discriminator, Mtag=0):
    """CT-GAN consistency term (Wei et al., 2018).

    Runs the stochastic (dropout-regularised) discriminator twice on the same
    real batch and penalises disagreement between the two passes, both at the
    output and at the second-to-last layer.

    Args:
        real_data: real batch.
        discriminator: callable returning ``(score, pre_last_features)``.
        Mtag: slack constant M' from the paper; disagreement below it is
            not penalised.

    Returns:
        Scalar tensor: mean consistency term over the batch.
    """
    d1, d_1 = discriminator(real_data)
    d2, d_2 = discriminator(real_data)

    # The paper defines CT = E[max(0, ...)]. Although each norm is
    # non-negative, the sum minus Mtag can go negative for Mtag > 0, and a
    # negative value must not reduce the loss — hence the clamp (a no-op for
    # the default Mtag=0, so existing behaviour is preserved).
    ct = (d1 - d2).norm(2, dim=1) + 0.1 * (d_1 - d_2).norm(2, dim=1) - Mtag
    return ct.clamp(min=0).mean()
import argparse
import os
from datetime import datetime

import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from tqdm import tqdm, trange

import loss_functions
import models

parser = argparse.ArgumentParser(description='PyTorch Wasserstein GAN Training')

parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', help='results dir')
parser.add_argument('--save', metavar='SAVE', default='', help='saved folder')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE', help='evaluate model FILE on validation set')

parser.add_argument('--dataset', metavar='DATASET', default='celeba', help='dataset name')
parser.add_argument('--dataset-path', metavar='DATASET_PATH', default='./dataset', help='dataset folder')

parser.add_argument('--input-size', type=int, default=64, help='image input size')
parser.add_argument('--channels', type=int, default=3, help='input image channels')
parser.add_argument('--z-size', type=int, default=128, help='size of the latent z vector')
parser.add_argument('--gen-filters', type=int, default=64)
parser.add_argument('--disc-filters', type=int, default=64)
parser.add_argument('--lambda1', type=float, default=10, help='Gradient penalty multiplier')
parser.add_argument('--lambda2', type=float, default=2, help='Consistency term multiplier')
parser.add_argument('--Mtag', type=float, default=0, help="Consistency term slack constant M'")
parser.add_argument('--disc-iters', type=int, default=5,
                    help='number of discriminator iterations per each generator iteration')
parser.add_argument('--n_extra_layers', type=int, default=0,
                    help='Number of extra layers for generator and discriminator')

parser.add_argument('--gpus', default='0', help='gpus used for training - e.g 0,1,3')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')

parser.add_argument('--epochs', default=1500, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)')

parser.add_argument('--lrg', '--learning-rate-gen', default=2e-4, type=float, metavar='LRG',
                    help='initial learning rate for generator')
parser.add_argument('--lrd', '--learning-rate-disc', default=2e-4, type=float, metavar='LRD',
                    help='initial learning rate for discriminator')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                    help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N', help='print frequency (default: 200)')

parser.add_argument('--seed', default=42, type=int, help='random seed (default: 42)')


def main():
    """Train (or resume/evaluate) a CT-GAN: WGAN-GP plus the consistency term
    of Wei et al. (2018). Saves sample grids and a checkpoint every epoch."""
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # An empty --save means "use the timestamp". The original compared with
    # `is ''` (identity, not equality), which is fragile; use truthiness.
    if not args.save:
        args.save = time_stamp

    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if not os.path.exists(args.dataset_path):
        os.makedirs(args.dataset_path)

    # Datasets are normalised to [-1, 1] to match the generator's Tanh output.
    if args.dataset == 'cifar10':
        dataset = datasets.CIFAR10(root=args.dataset_path, download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.input_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
        args.channels = 3
    elif args.dataset == 'mnist':
        dataset = datasets.MNIST(root=args.dataset_path, download=True,
                                 transform=transforms.Compose([
                                     transforms.Resize(args.input_size),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5,), (0.5,)),
                                 ]))
        args.channels = 1
    else:
        # Folder dataset (e.g. CelebA extracted under --dataset-path).
        dataset = datasets.ImageFolder(root=args.dataset_path,
                                       transform=transforms.Compose([
                                           transforms.Resize(args.input_size),
                                           transforms.CenterCrop(args.input_size),
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                       ]))
        args.channels = 3
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                                             num_workers=args.workers)

    netG = models.WGANGenerator(output_size=args.input_size, out_channels=args.channels, num_filters=args.gen_filters)
    print(netG)
    netD = models.WGANDiscriminator(input_size=args.input_size, in_channels=args.channels,
                                    num_filters=args.disc_filters)
    print(netD)

    # Fixed latent batch so the per-epoch sample grids are comparable.
    fixed_noise = torch.FloatTensor(args.batch_size, args.z_size, 1, 1).normal_(0, 1)

    args.gpus = [int(i) for i in args.gpus.split(',')]
    torch.cuda.set_device(args.gpus[0])
    cudnn.benchmark = True

    netG = torch.nn.DataParallel(netG, args.gpus).cuda()
    netD = torch.nn.DataParallel(netD, args.gpus).cuda()
    # Optionally restore from a checkpoint.
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            print('invalid checkpoint: {}'.format(args.evaluate))
            return
        checkpoint = torch.load(args.evaluate)
        # Restore each network from its own weights (the original loaded both
        # state dicts into netG and left netD uninitialised).
        netD.load_state_dict(checkpoint['d_state_dict'])
        netG.load_state_dict(checkpoint['g_state_dict'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(checkpoint_file, 'checkpoint.pth')
        if not os.path.isfile(checkpoint_file):
            print('invalid checkpoint: {}'.format(checkpoint_file))
            return
        print("loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(checkpoint_file)
        args.start_epoch = checkpoint['epoch']
        netD.load_state_dict(checkpoint['d_state_dict'])
        netG.load_state_dict(checkpoint['g_state_dict'])
        print("loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))

    optimizerD = torch.optim.Adam(netD.parameters(), lr=args.lrd)
    optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lrg)

    for epoch in trange(args.epochs):
        disc_iter = args.disc_iters
        for i, (real_inputs, _) in enumerate(tqdm(dataloader)):
            # --- Train the discriminator (critic) ---
            for p in netD.parameters():
                p.requires_grad = True
            optimizerD.zero_grad()

            # Real data: critic maximises its score, so minimise the negative.
            real_inputs = real_inputs.cuda()
            errD_real, _ = netD(real_inputs)
            errD_real = -errD_real

            # Fake data: detach the generator from the critic's graph.
            noise = torch.FloatTensor(real_inputs.shape[0], args.z_size, 1, 1).normal_(0, 1)
            with torch.no_grad():
                fake_inputs = netG(noise)

            errD_fake, _ = netD(fake_inputs)

            gp = loss_functions.gradient_penalty(fake_inputs.data, real_inputs.data, netD)
            ct = loss_functions.consistency_term(real_inputs, netD, args.Mtag)
            errD = errD_real.mean(0) + errD_fake.mean(0) + args.lambda1 * gp + args.lambda2 * ct
            errD.backward()
            optimizerD.step()

            # --- Train the generator once every disc_iter critic steps ---
            if i % disc_iter == 0:
                # Freeze the critic so generator updates don't touch it.
                for p in netD.parameters():
                    p.requires_grad = False
                optimizerG.zero_grad()
                noise = torch.FloatTensor(args.batch_size, args.z_size, 1, 1).normal_(0, 1)
                fake = netG(noise)
                errG, _ = netD(fake)
                errG = -errG.mean(0)
                errG.backward()
                optimizerG.step()

            if i % args.print_freq == 0:
                tqdm.write('[{}/{}][{}/{}] Loss_D: {} Loss_G: {} '
                           'Loss_D_real: {} Loss_D_fake {}'.format(epoch, args.epochs, i, len(dataloader),
                                                                   errD.data.mean(), errG.data.mean(),
                                                                   errD_real.data.mean(), errD_fake.data.mean()))

        # Un-normalise from [-1, 1] back to [0, 1] before saving images.
        real_inputs = real_inputs.mul(0.5).add(0.5)
        vutils.save_image(real_inputs.data, '{0}/real_samples.png'.format(save_path))
        with torch.no_grad():
            fake = netG(fixed_noise)
        fake.data = fake.data.mul(0.5).add(0.5)
        vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(save_path, epoch))

        torch.save({
            'epoch': epoch,
            'd_state_dict': netD.state_dict(),
            'g_state_dict': netG.state_dict(), },
            os.path.join(save_path, 'checkpoint.pth'))


if __name__ == '__main__':
    main()
from collections import OrderedDict

import numpy as np
import torch.nn as nn
from torch.nn import init


class WGAN(nn.Module):
    """Common base for the WGAN generator and discriminator.

    Provides DCGAN-style weight initialisation and a builder for optional
    shape-preserving "extra" layers.
    """

    def __init__(self):
        super(WGAN, self).__init__()

    def init_params(self):
        """Initialise weights DCGAN-style: conv/deconv/linear weights from
        N(0, 0.02), BatchNorm scale 1 / shift 0, all biases 0.

        Fix: the original matched only ``nn.Conv2d`` — ``nn.ConvTranspose2d``
        is NOT a subclass of it, so the generator's transposed convolutions
        silently kept PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def _make_extra(self, layer_type, num_filters, n_extra_layers, drop=False, dropout=0):
        """Build ``n_extra_layers`` shape-preserving 3x3 conv blocks.

        Each block is layer -> BatchNorm -> activation (-> Dropout if ``drop``),
        keeping the channel count at ``num_filters``.
        """
        modules = OrderedDict()
        for i in range(n_extra_layers):
            parts = [
                layer_type(num_filters, num_filters, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(num_filters),
                self.activation(inplace=True),
            ]
            if drop:
                parts.append(nn.Dropout(p=dropout, inplace=True))
            modules["ExtraLayers_{}".format(i + 1)] = nn.Sequential(*parts)
        return nn.Sequential(modules)


class WGANGenerator(WGAN):
    """DCGAN-style generator.

    Maps a latent code of shape (N, input_size, 1, 1) to an image of shape
    (N, out_channels, output_size, output_size) with values in [-1, 1].
    """

    def __init__(self, input_size=128, num_filters=64, out_channels=3, output_size=32, n_extra_layers=0,
                 activation=nn.ReLU):
        super(WGANGenerator, self).__init__()
        self.activation = activation
        self.has_extra = n_extra_layers > 0
        self.out_size = output_size

        # Widest layer: halved at every upsampling stage down to num_filters.
        widest = (self.out_size // 8) * num_filters
        # 1x1 -> 4x4 projection of the latent vector.
        self.first_conv_trans = nn.Sequential(
            nn.ConvTranspose2d(input_size, widest, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(widest),
            activation(inplace=True))

        self.conv_trans = self._make_conv_trans(widest)

        if self.has_extra:
            self.extra = self._make_extra(nn.ConvTranspose2d, num_filters, n_extra_layers)
        # Final x2 upsampling to the target resolution, squashed to [-1, 1].
        self.final_conv_trans = nn.Sequential(
            nn.ConvTranspose2d(num_filters, out_channels, kernel_size=5, stride=2, padding=2, output_padding=1,
                               bias=False),
            nn.Tanh())

        self.init_params()

    def forward(self, x):
        out = self.first_conv_trans(x)
        out = self.conv_trans(out)
        if self.has_extra:
            out = self.extra(out)
        return self.final_conv_trans(out)

    def _make_conv_trans(self, num_filters):
        """Chain of x2-upsampling blocks from 4x4 up to out_size/2, halving
        the channel count at each stage."""
        modules = OrderedDict()
        width = num_filters
        for i in range(int(np.log2(self.out_size)) - 3):
            modules["ConvTranspose_{}".format(i + 1)] = nn.Sequential(
                nn.ConvTranspose2d(width, width // 2, kernel_size=5, stride=2, padding=2, output_padding=1,
                                   bias=False),
                nn.BatchNorm2d(width // 2),
                self.activation(inplace=True))
            width //= 2
        return nn.Sequential(modules)


class WGANDiscriminator(WGAN):
    """DCGAN-style critic (no final sigmoid, as required by WGAN).

    ``forward`` returns ``(score, pre_last)``: the unbounded per-sample score
    map and the features feeding the final convolution. The second output and
    the Dropout layers provide the stochasticity used by the consistency term.
    """

    def __init__(self, input_size=32, in_channels=3, num_filters=64,
                 n_extra_layers=0, activation=nn.LeakyReLU, dropout_rate=0.5):
        super(WGANDiscriminator, self).__init__()
        self.activation = activation
        self.has_extra = n_extra_layers > 0
        self.input_size = input_size
        self.dropout_rate = dropout_rate

        # x2-downsampling entry block.
        self.first_conv = nn.Sequential(nn.Conv2d(in_channels, num_filters, kernel_size=5, stride=2, padding=2),
                                        activation(),
                                        nn.Dropout(p=dropout_rate))
        if self.has_extra:
            self.extra = self._make_extra(nn.Conv2d, num_filters, n_extra_layers, drop=True, dropout=dropout_rate)

        self.conv, last_filter = self._make_conv(num_filters)
        # Collapse the remaining 4x4 map to a single score per sample.
        self.final_conv = nn.Conv2d(last_filter, 1, kernel_size=4, stride=1, padding=0, bias=False)

        self.init_params()

    def forward(self, x):
        out = self.first_conv(x)
        if self.has_extra:
            out = self.extra(out)
        pre_last = self.conv(out)
        return self.final_conv(pre_last), pre_last

    def _make_conv(self, num_filters):
        """Chain of x2-downsampling blocks from input_size/2 down to 4x4,
        doubling the channel count at each stage. Returns (module, width)."""
        modules = OrderedDict()
        width = num_filters
        for i in range(int(np.log2(self.input_size)) - 3):
            modules["Conv_{}".format(i + 1)] = nn.Sequential(
                nn.Conv2d(width, width * 2, kernel_size=5, stride=2, padding=2),
                self.activation(),
                nn.Dropout(p=self.dropout_rate))
            width *= 2
        return nn.Sequential(modules), width