├── Code ├── GAN.py ├── create_FID_stats.py ├── experiments.sh ├── fid.py ├── fid_new.py ├── fid_script.sh ├── preprocess_cat_dataset.py ├── pytorch_visualize.py ├── setting_up_script.sh └── startup_tmp.sh ├── LICENSE └── README.md /Code/GAN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # To get TensorBoard output, use the python command: tensorboard --logdir /home/alexia/Output/DCGAN 4 | # TensorBoard disabled for now. 5 | 6 | # To get CIFAR10 7 | # wget http://pjreddie.com/media/files/cifar.tgz 8 | # tar xzf cifar.tgz 9 | 10 | 11 | ## Parameters 12 | 13 | # thanks https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 14 | def strToBool(str): 15 | return str.lower() in ('true', 'yes', 'on', 't', '1') 16 | 17 | import argparse 18 | parser = argparse.ArgumentParser() 19 | parser.register('type', 'bool', strToBool) 20 | parser.add_argument('--image_size', type=int, default=64) 21 | parser.add_argument('--batch_size', type=int, default=32) # DCGAN paper original value used 128 (32 is generally better to prevent vanishing gradients with SGAN and LSGAN, not important with relativistic GANs) 22 | parser.add_argument('--n_colors', type=int, default=3) 23 | parser.add_argument('--z_size', type=int, default=128) 24 | parser.add_argument('--G_h_size', type=int, default=128, help='Number of hidden nodes in the Generator. Used only in arch=0. Too small leads to bad results, too big blows up the GPU RAM.') # DCGAN paper original value 25 | parser.add_argument('--D_h_size', type=int, default=128, help='Number of hidden nodes in the Discriminator. Used only in arch=0. 
Too small leads to bad results, too big blows up the GPU RAM.') # DCGAN paper original value 26 | parser.add_argument('--conv_size', type=int, default=64, help='Size of convolutions when using Self-Attention GAN.') 27 | parser.add_argument('--resample', type=int, default=1, help="Resample data in the generator step (Recommended, may affect performance slightly)") 28 | parser.add_argument('--centercrop', type=int, default=0, help="If not 0, center-crop the images to this size before resizing") 29 | parser.add_argument('--lr_D', type=float, default=.0001, help='Discriminator learning rate') 30 | parser.add_argument('--lr_G', type=float, default=.0001, help='Generator learning rate') 31 | parser.add_argument('--n_iter', type=int, default=100000, help='Number of iteration cycles') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='Adam betas[0], DCGAN paper recommends .50 instead of the usual .90') 33 | parser.add_argument('--beta2', type=float, default=0.999, help='Adam betas[1]') 34 | parser.add_argument('--decay', type=float, default=0, help='Decay to apply to lr each cycle. (1-decay)^n_iter gives the final fraction of lr. Ex: .00002 will lead to .13 of lr after 100k cycles') 35 | parser.add_argument('--SELU', type='bool', default=False, help='Using scaled exponential linear units (SELU) which are self-normalizing instead of ReLU with BatchNorm. Used only in arch=0. This improves stability.') 36 | parser.add_argument("--NN_conv", type='bool', default=False, help="This approach minimizes checkerboard artifacts during training. Used only by arch=0. 
Uses nearest-neighbor resized convolutions instead of strided convolutions (https://distill.pub/2016/deconv-checkerboard/ and github.com/abhiskk/fast-neural-style).") 37 | parser.add_argument('--seed', type=int) 38 | parser.add_argument('--input_folder', default='/Datasets/Meow_32x32', help='input folder') 39 | parser.add_argument('--output_folder', default='/Output/GANlosses', help='output folder') 40 | parser.add_argument('--load', default=None, help='Full path to network state to load (ex: /network/home/output_folder/run-5/models/state_11.pth)') 41 | parser.add_argument('--cuda', type='bool', default=True, help='enables cuda') 42 | parser.add_argument('--n_gpu', type=int, default=1, help='number of GPUs to use') 43 | parser.add_argument('--loss_D', type=int, default=1, help='Loss of D, see code for details') 44 | parser.add_argument('--Diters', type=int, default=1, help='Number of iterations of D') 45 | parser.add_argument('--Giters', type=int, default=1, help='Number of iterations of G.') 46 | parser.add_argument('--spectral', type='bool', default=False, help='If True, use spectral normalization to make the discriminator Lipschitz. This will also remove batch norm in the discriminator.') 47 | parser.add_argument('--spectral_G', type='bool', default=False, help='If True, use spectral normalization to make the generator Lipschitz (Generally only D is spectral, not G). This will also remove batch norm in the generator.') 48 | parser.add_argument('--weight_decay', type=float, default=0, help='L2 regularization weight. 
Helps convergence but leads to artifacts in images, not recommended.') 49 | parser.add_argument('--gen_extra_images', type=int, default=50000, help='Generate additional images with random fake cats for calculating FID (Recommended to use the same amount as the size of the dataset; for CIFAR-10 we use 50k, but most people use 10k). It must be a multiple of 100.') 50 | parser.add_argument('--gen_every', type=int, default=100000, help='Generate additional images with random fake cats every x iterations. Used in calculating FID.') 51 | parser.add_argument('--logs_folder', default='/scratch/jolicoea/Output/Extra', help='Folder for models and FID logs') 52 | parser.add_argument('--extra_folder', default='/Output/Extra', help='Folder for images for FID calculation') 53 | parser.add_argument('--show_graph', type='bool', default=False, help='If True, show gradients graph. Really neat for debugging.') 54 | parser.add_argument('--no_batch_norm_G', type='bool', default=False, help='If True, no batch norm in G.') 55 | parser.add_argument('--no_batch_norm_D', type='bool', default=False, help='If True, no batch norm in D.') 56 | parser.add_argument('--Tanh_GD', type='bool', default=False, help='If True, tanh everywhere.') 57 | parser.add_argument('--arch', type=int, default=0, help='0: DCGAN with number of layers adjusted based on image size, 1: standard CNN for 32x32 images from the Spectral GAN paper, 2: ResNet for 32x32 images adapted from a WGAN-GP implementation. Some options may be ignored by some architectures.') 58 | parser.add_argument('--print_every', type=int, default=1000, help='Generate a mini-batch of images every x iterations (to see how training progresses, you can do it often).') 59 | parser.add_argument('--save', type='bool', default=False, help='Do we save models, yes or no? It will be saved in extra_folder') 60 | parser.add_argument('--CIFAR10', type='bool', default=False, help='If True, use CIFAR-10 instead of your own dataset. 
Make sure image_size is set to 32!') 61 | parser.add_argument('--CIFAR10_input_folder', default='/Datasets/CIFAR10', help='input folder (automatically downloaded)') 62 | parser.add_argument('--LSUN', type='bool', default=False, help='If True, use LSUN instead of your own dataset.') 63 | parser.add_argument('--LSUN_input_folder', default='/Datasets/LSUN', help='input folder') 64 | parser.add_argument('--LSUN_classes', nargs='+', default='bedroom_train', help='Classes to use (see https://pytorch.org/docs/stable/torchvision/datasets.html#lsun)') 65 | parser.add_argument("--no_bias", type='bool', default=False, help="Unbiased estimator when using RaLSGAN (loss_D=12) or RcLSGAN (loss_D=22)") 66 | 67 | # Options for Gradient penalties 68 | parser.add_argument('--penalty', type=float, default=20, help='Coefficient (weight) of the gradient penalty term') 69 | parser.add_argument('--grad_penalty', type='bool', default=False, help='If True, use gradient penalty of WGAN-GP but with whichever loss_D chosen. No need to set this true with WGAN-GP.') 70 | parser.add_argument('--no_grad_penalty', type='bool', default=False, help='If True, do not use gradient penalty when using WGAN-GP (If you want to try Spectral normalized WGAN).') 71 | parser.add_argument('--grad_penalty_aug', type='bool', default=False, help='If True, use augmented lagrangian for gradient penalty (aka Sobolev GAN).') 72 | parser.add_argument('--fake_only', type='bool', default=False, help='Using fake data only in gradient penalty') 73 | parser.add_argument('--real_only', type='bool', default=False, help='Using real data only in gradient penalty') 74 | parser.add_argument('--delta', type=float, default=1, help='(||grad_D(x)|| - delta)^2') 75 | parser.add_argument('--rho', type=float, default=.0001, help='learning rate of lagrange multiplier when using augmented lagrangian') 76 | parser.add_argument('--penalty-type', help='Gradient penalty type. 
The default ("squared-diff") forces gradient norm *equal* to 1, which is not correct, but is what is done in the original WGAN-GP paper. True Lipschitz constraint is with "clamp"', 77 | choices=['clamp','squared-clamp','squared-diff','squared','TV','abs','hinge','hinge2'], default='squared-diff') 78 | parser.add_argument('--reduction', help='Summary statistic for gradient penalty over batches (default: "mean")', 79 | choices=['mean','max','softmax'],default='mean') 80 | parser.add_argument('--activation', help='Activation used in the discriminator (default: "leaky"); read below as param.activation', choices=['leaky','softplus'], default='leaky') 81 | # Max Margin 82 | parser.add_argument('--l1_margin', help='maximize L-1 margin (equivalent to penalizing L-infinity gradient norm)',action='store_true') 83 | parser.add_argument('--l1_margin_logsumexp', help='maximize L-1 margin using logsumexp to approximate L-infinity gradient norm (equivalent to penalizing L-infinity gradient norm)',action='store_true') 84 | parser.add_argument('--l1_margin_smoothmax', help='maximize L-1 margin using smooth max to approximate L-infinity gradient norm (equivalent to penalizing L-infinity gradient norm)',action='store_true') 85 | parser.add_argument('--linf_margin', help='maximize L-infinity margin (equivalent to penalizing L-1 gradient norm)',action='store_true') 86 | parser.add_argument('--smoothmax', type=float, default=.5, help='parameter for smooth max (higher = less smooth)') 87 | parser.add_argument('--l1_margin_no_abs', help="Only penalize positive gradient (Shouldn't work, but it does)",action='store_true') 88 | 89 | param = parser.parse_args() 90 | print('Arguments:') 91 | for p in vars(param).items(): 92 | print(' ',p[0]+': ',p[1]) 93 | print('\n') 94 | 95 | ## Imports 96 | import torch.nn.functional as F 97 | 98 | # Time 99 | import time 100 | import sys 101 | start = time.time() 102 | 103 | # Setting the title for the file saved 104 | if param.loss_D == 1: 105 | title = 'GAN_' 106 | if param.loss_D == 2: 107 | title = 'LSGAN_' 108 | if param.loss_D == 3: 109 | title = 'HingeGAN_' 110 | if param.loss_D == 4: 111 | title = 'WGANGP_' 112 | 113 | if 
param.loss_D == 11: 114 | title = 'RaSGAN_' 115 | if param.loss_D == 12: 116 | title = 'RaLSGAN_' 117 | if param.loss_D == 13: 118 | title = 'RaHingeGAN_' 119 | 120 | if param.loss_D == 21: 121 | title = 'RcSGAN_' 122 | if param.loss_D == 22: 123 | title = 'RcLSGAN_' 124 | if param.loss_D == 23: 125 | title = 'RcHingeGAN_' 126 | 127 | if param.loss_D == 31: 128 | title = 'RpSGAN_' 129 | if param.loss_D == 32: 130 | title = 'RpLSGAN_' 131 | if param.loss_D == 33: 132 | title = 'RpHingeGAN_' 133 | 134 | if param.loss_D == 41: 135 | title = 'RpSGAN_MVUE_' 136 | if param.loss_D == 42: 137 | title = 'RpLSGAN_MVUE_' 138 | if param.loss_D == 43: 139 | title = 'RpHingeGAN_MVUE_' 140 | 141 | if param.no_bias: 142 | title = title + 'nobias_' 143 | 144 | if param.seed is not None: 145 | title = title + 'seed%i' % param.seed 146 | 147 | 148 | # Check folder title-i for i=0,1,... until finding a title-j that does not exist, then create the new folder title-j 149 | import os 150 | 151 | # Add local variable to folders so we work in the local drive (removed for public release) 152 | #param.input_folder = os.environ["SLURM_TMPDIR"] + param.input_folder 153 | #param.output_folder = os.environ["SLURM_TMPDIR"] + param.output_folder 154 | #param.CIFAR10_input_folder = os.environ["SLURM_TMPDIR"] + param.CIFAR10_input_folder 155 | #param.LSUN_input_folder = os.environ["SLURM_TMPDIR"] + param.LSUN_input_folder 156 | #param.extra_folder = os.environ["SLURM_TMPDIR"] + param.extra_folder 157 | 158 | run = 0 159 | base_dir = f"{param.output_folder}/{title}-{run}" 160 | while os.path.exists(base_dir): 161 | run += 1 162 | base_dir = f"{param.output_folder}/{title}-{run}" 163 | os.makedirs(base_dir) 164 | logs_dir = f"{base_dir}/logs" 165 | os.makedirs(logs_dir) 166 | os.makedirs(f"{base_dir}/images") 167 | if param.gen_extra_images > 0 and not os.path.exists(f"{param.extra_folder}"): 168 | os.makedirs(f"{param.extra_folder}") 169 | if param.gen_extra_images > 0 and not 
os.path.exists(f"{param.logs_folder}"): 170 | os.makedirs(f"{param.logs_folder}") 171 | 172 | # where we save the output 173 | log_output = open(f"{logs_dir}/log.txt", 'w') 174 | print(param, file=log_output) 175 | 176 | import numpy 177 | import torch 178 | import torch.autograd as autograd 179 | from torch.autograd import Variable 180 | 181 | # For plotting the Loss of D and G using tensorboard 182 | # To fix later, not compatible with using tensorflow 183 | #from tensorboard_logger import configure, log_value 184 | #configure(logs_dir, flush_secs=5) 185 | 186 | import torchvision 187 | import torchvision.datasets as dset 188 | import torchvision.transforms as transf 189 | import torchvision.models as models 190 | import torchvision.utils as vutils 191 | import torch.nn.utils.spectral_norm as spectral_norm 192 | 193 | if param.cuda: 194 | import torch.backends.cudnn as cudnn 195 | cudnn.deterministic = True 196 | cudnn.benchmark = True 197 | 198 | # To see images 199 | from IPython.display import Image 200 | to_img = transf.ToPILImage() 201 | 202 | #import pytorch_visualize as pv 203 | 204 | import math 205 | 206 | torch.utils.backcompat.broadcast_warning.enabled=True 207 | 208 | from fid import calculate_fid_given_paths as calc_fid 209 | #from inception import get_inception_score 210 | #from inception import load_images 211 | 212 | ## Setting seed 213 | import random 214 | if param.seed is None: 215 | param.seed = random.randint(1, 10000) 216 | print(f"Random Seed: {param.seed}") 217 | print(f"Random Seed: {param.seed}", file=log_output) 218 | random.seed(param.seed) 219 | numpy.random.seed(param.seed) 220 | torch.manual_seed(param.seed) 221 | if param.cuda: 222 | torch.cuda.manual_seed_all(param.seed) 223 | 224 | ## Transforming images 225 | if param.centercrop != 0: 226 | trans = transf.Compose([ 227 | transf.CenterCrop(param.centercrop), 228 | transf.Resize((param.image_size, param.image_size)), 229 | # This makes it into [0,1] 230 | transf.ToTensor(), 231 | 
# This makes it into [-1,1] 232 | transf.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]) 233 | ]) 234 | else: 235 | trans = transf.Compose([ 236 | transf.Resize((param.image_size, param.image_size)), 237 | # This makes it into [0,1] 238 | transf.ToTensor(), 239 | # This makes it into [-1,1] 240 | transf.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]) 241 | ]) 242 | 243 | ## Importing dataset 244 | if param.CIFAR10: 245 | data = dset.CIFAR10(root=param.CIFAR10_input_folder, train=True, download=True, transform=trans) 246 | elif param.LSUN: 247 | print(param.LSUN_classes) 248 | data = dset.LSUN(root=param.LSUN_input_folder, classes=param.LSUN_classes if isinstance(param.LSUN_classes, list) else [param.LSUN_classes], transform=trans) # nargs='+' yields a list from the command line, while the default is a plain string 249 | else: 250 | data = dset.ImageFolder(root=param.input_folder, transform=trans) 251 | 252 | # Loading data randomly 253 | def generate_random_sample(size=param.batch_size): 254 | while True: 255 | random_indexes = numpy.random.choice(len(data), size=size, replace=False) 256 | batch = [data[i][0] for i in random_indexes] 257 | yield torch.stack(batch, 0) 258 | random_sample = generate_random_sample(size=param.batch_size) 259 | 260 | ## Models 261 | 262 | if param.activation=='leaky': 263 | class Activation(torch.nn.Module): 264 | def __init__(self): 265 | super().__init__() 266 | self.act = torch.nn.LeakyReLU(0.1 if param.arch == 1 else .02, inplace=True) 267 | def forward(self, x): 268 | return self.act(x) 269 | elif param.activation=='softplus': 270 | class Activation(torch.nn.Module): 271 | def __init__(self): 272 | super().__init__() 273 | #self.a = torch.nn.Parameter(torch.FloatTensor(1).fill_(1.)) 274 | self.a = torch.nn.Softplus(1, 20.) 275 | def forward(self, x): 276 | #return F.softplus(x, self.a, 20.) 
277 | return self.a(x) 278 | 279 | if param.arch == 1: 280 | title = title + '_CNN_' 281 | 282 | class DCGAN_G(torch.nn.Module): 283 | def __init__(self): 284 | super(DCGAN_G, self).__init__() 285 | 286 | self.dense = torch.nn.Linear(param.z_size, 512 * 4 * 4) 287 | 288 | if param.spectral_G: 289 | model = [spectral_norm(torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True))] 290 | model += [torch.nn.ReLU(True), 291 | spectral_norm(torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True))] 292 | model += [torch.nn.ReLU(True), 293 | spectral_norm(torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True))] 294 | model += [torch.nn.ReLU(True), 295 | spectral_norm(torch.nn.Conv2d(64, param.n_colors, kernel_size=3, stride=1, padding=1, bias=True)), 296 | torch.nn.Tanh()] 297 | else: 298 | model = [torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)] 299 | if not param.no_batch_norm_G: 300 | model += [torch.nn.BatchNorm2d(256)] 301 | if param.Tanh_GD: 302 | model += [torch.nn.Tanh()] 303 | else: 304 | model += [torch.nn.ReLU(True)] 305 | model += [torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True)] 306 | if not param.no_batch_norm_G: 307 | model += [torch.nn.BatchNorm2d(128)] 308 | if param.Tanh_GD: 309 | model += [torch.nn.Tanh()] 310 | else: 311 | model += [torch.nn.ReLU(True)] 312 | model += [torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True)] 313 | if not param.no_batch_norm_G: 314 | model += [torch.nn.BatchNorm2d(64)] 315 | if param.Tanh_GD: 316 | model += [torch.nn.Tanh()] 317 | else: 318 | model += [torch.nn.ReLU(True)] 319 | model += [torch.nn.Conv2d(64, param.n_colors, kernel_size=3, stride=1, padding=1, bias=True), 320 | torch.nn.Tanh()] 321 | self.model = torch.nn.Sequential(*model) 322 | 323 | def forward(self, input): 324 | if isinstance(input.data, torch.cuda.FloatTensor) and param.n_gpu > 1: 
325 | output = torch.nn.parallel.data_parallel(self.model, self.dense(input.view(-1, param.z_size)).view(-1, 512, 4, 4), range(param.n_gpu)) # data_parallel expects (module, inputs, device_ids); the dense layer runs on the main device first 326 | else: 327 | output = self.model(self.dense(input.view(-1, param.z_size)).view(-1, 512, 4, 4)) 328 | #print(output.size()) 329 | return output 330 | 331 | class DCGAN_D(torch.nn.Module): 332 | def __init__(self): 333 | super(DCGAN_D, self).__init__() 334 | 335 | self.dense = torch.nn.Linear(512 * 4 * 4, 1) 336 | 337 | if param.spectral: 338 | model = [spectral_norm(torch.nn.Conv2d(param.n_colors, 64, kernel_size=3, stride=1, padding=1, bias=True)), 339 | Activation(), 340 | spectral_norm(torch.nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True)), 341 | Activation(), 342 | 343 | spectral_norm(torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)), 344 | Activation(), 345 | spectral_norm(torch.nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True)), 346 | Activation(), 347 | 348 | spectral_norm(torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)), 349 | Activation(), 350 | spectral_norm(torch.nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1, bias=True)), 351 | Activation(), 352 | 353 | spectral_norm(torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True)), 354 | Activation()] 355 | else: 356 | model = [torch.nn.Conv2d(param.n_colors, 64, kernel_size=3, stride=1, padding=1, bias=True)] 357 | if not param.no_batch_norm_D: 358 | model += [torch.nn.BatchNorm2d(64)] 359 | if param.Tanh_GD: 360 | model += [torch.nn.Tanh()] 361 | else: 362 | model += [Activation()] 363 | model += [torch.nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True)] 364 | if not param.no_batch_norm_D: 365 | model += [torch.nn.BatchNorm2d(64)] 366 | if param.Tanh_GD: 367 | model += [torch.nn.Tanh()] 368 | else: 369 | model += [Activation()] 370 | model += [torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)] 371 | if not 
param.no_batch_norm_D: 372 | model += [torch.nn.BatchNorm2d(128)] 373 | if param.Tanh_GD: 374 | model += [torch.nn.Tanh()] 375 | else: 376 | model += [Activation()] 377 | model += [torch.nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True)] 378 | if not param.no_batch_norm_D: 379 | model += [torch.nn.BatchNorm2d(128)] 380 | if param.Tanh_GD: 381 | model += [torch.nn.Tanh()] 382 | else: 383 | model += [Activation()] 384 | model += [torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)] 385 | if not param.no_batch_norm_D: 386 | model += [torch.nn.BatchNorm2d(256)] 387 | if param.Tanh_GD: 388 | model += [torch.nn.Tanh()] 389 | else: 390 | model += [Activation()] 391 | model += [torch.nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1, bias=True)] 392 | if not param.no_batch_norm_D: 393 | model += [torch.nn.BatchNorm2d(256)] 394 | if param.Tanh_GD: 395 | model += [torch.nn.Tanh()] 396 | else: 397 | model += [Activation()] 398 | model += [torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True)] 399 | if param.Tanh_GD: 400 | model += [torch.nn.Tanh()] 401 | else: 402 | model += [Activation()] 403 | self.model = torch.nn.Sequential(*model) 404 | 405 | def forward(self, input): 406 | if isinstance(input.data, torch.cuda.FloatTensor) and param.n_gpu > 1: 407 | output = self.dense(torch.nn.parallel.data_parallel(self.model, input, range(param.n_gpu)).view(-1, 512 * 4 * 4)).view(-1) # data_parallel expects (module, inputs, device_ids); the dense layer runs on the gathered features 408 | else: 409 | output = self.dense(self.model(input).view(-1, 512 * 4 * 4)).view(-1) 410 | #print(output.size()) 411 | return output 412 | 413 | if param.arch == 0: 414 | 415 | # DCGAN generator 416 | class DCGAN_G(torch.nn.Module): 417 | def __init__(self): 418 | super(DCGAN_G, self).__init__() 419 | main = torch.nn.Sequential() 420 | 421 | # We need to know how many layers we will use at the beginning 422 | mult = param.image_size // 8 423 | 424 | ### Start block 425 | # Z_size random numbers 426 | if param.spectral_G: 427 | 
main.add_module('Start-SpectralConvTranspose2d', torch.nn.utils.spectral_norm(torch.nn.ConvTranspose2d(param.z_size, param.G_h_size * mult, kernel_size=4, stride=1, padding=0, bias=False))) 428 | else: 429 | main.add_module('Start-ConvTranspose2d', torch.nn.ConvTranspose2d(param.z_size, param.G_h_size * mult, kernel_size=4, stride=1, padding=0, bias=False)) 430 | if param.SELU: 431 | main.add_module('Start-SELU', torch.nn.SELU(inplace=True)) 432 | else: 433 | if not param.no_batch_norm_G and not param.spectral_G: 434 | main.add_module('Start-BatchNorm2d', torch.nn.BatchNorm2d(param.G_h_size * mult)) 435 | if param.Tanh_GD: 436 | main.add_module('Start-Tanh', torch.nn.Tanh()) 437 | else: 438 | main.add_module('Start-ReLU', torch.nn.ReLU()) 439 | # Size = (G_h_size * mult) x 4 x 4 440 | 441 | ### Middle block (Done until we reach ? x image_size/2 x image_size/2) 442 | ii = 1 443 | while mult > 1: 444 | if param.NN_conv: 445 | main.add_module('Middle-UpSample [%d]' % ii, torch.nn.Upsample(scale_factor=2)) 446 | if param.spectral_G: 447 | main.add_module('Middle-SpectralConv2d [%d]' % ii, torch.nn.utils.spectral_norm(torch.nn.Conv2d(param.G_h_size * mult, param.G_h_size * (mult//2), kernel_size=3, stride=1, padding=1))) 448 | else: 449 | main.add_module('Middle-Conv2d [%d]' % ii, torch.nn.Conv2d(param.G_h_size * mult, param.G_h_size * (mult//2), kernel_size=3, stride=1, padding=1)) 450 | else: 451 | if param.spectral_G: 452 | main.add_module('Middle-SpectralConvTranspose2d [%d]' % ii, torch.nn.utils.spectral_norm(torch.nn.ConvTranspose2d(param.G_h_size * mult, param.G_h_size * (mult//2), kernel_size=4, stride=2, padding=1, bias=False))) 453 | else: 454 | main.add_module('Middle-ConvTranspose2d [%d]' % ii, torch.nn.ConvTranspose2d(param.G_h_size * mult, param.G_h_size * (mult//2), kernel_size=4, stride=2, padding=1, bias=False)) 455 | if param.SELU: 456 | main.add_module('Middle-SELU [%d]' % ii, torch.nn.SELU(inplace=True)) 457 | else: 458 | if not param.no_batch_norm_G 
and not param.spectral_G: 459 | main.add_module('Middle-BatchNorm2d [%d]' % ii, torch.nn.BatchNorm2d(param.G_h_size * (mult//2))) 460 | if param.Tanh_GD: 461 | main.add_module('Middle-Tanh [%d]' % ii, torch.nn.Tanh()) 462 | else: 463 | main.add_module('Middle-ReLU [%d]' % ii, torch.nn.ReLU()) 464 | # Size = (G_h_size * (mult/(2*i))) x 8 x 8 465 | mult = mult // 2 466 | ii += 1 467 | 468 | ### End block 469 | # Size = G_h_size x image_size/2 x image_size/2 470 | if param.NN_conv: 471 | main.add_module('End-UpSample', torch.nn.Upsample(scale_factor=2)) 472 | if param.spectral_G: 473 | main.add_module('End-SpectralConv2d', torch.nn.utils.spectral_norm(torch.nn.Conv2d(param.G_h_size, param.n_colors, kernel_size=3, stride=1, padding=1))) 474 | else: 475 | main.add_module('End-Conv2d', torch.nn.Conv2d(param.G_h_size, param.n_colors, kernel_size=3, stride=1, padding=1)) 476 | else: 477 | if param.spectral_G: 478 | main.add_module('End-SpectralConvTranspose2d', torch.nn.utils.spectral_norm(torch.nn.ConvTranspose2d(param.G_h_size, param.n_colors, kernel_size=4, stride=2, padding=1, bias=False))) 479 | else: 480 | main.add_module('End-ConvTranspose2d', torch.nn.ConvTranspose2d(param.G_h_size, param.n_colors, kernel_size=4, stride=2, padding=1, bias=False)) 481 | main.add_module('End-Tanh', torch.nn.Tanh()) 482 | # Size = n_colors x image_size x image_size 483 | self.main = main 484 | 485 | def forward(self, input): 486 | if isinstance(input.data, torch.cuda.FloatTensor) and param.n_gpu > 1: 487 | output = torch.nn.parallel.data_parallel(self.main, input, range(param.n_gpu)) 488 | else: 489 | output = self.main(input) 490 | return output 491 | 492 | # DCGAN discriminator (using somewhat the reverse of the generator) 493 | class DCGAN_D(torch.nn.Module): 494 | def __init__(self): 495 | super(DCGAN_D, self).__init__() 496 | main = torch.nn.Sequential() 497 | 498 | ### Start block 499 | # Size = n_colors x image_size x image_size 500 | if param.spectral: 501 | 
main.add_module('Start-SpectralConv2d', torch.nn.utils.spectral_norm(torch.nn.Conv2d(param.n_colors, param.D_h_size, kernel_size=4, stride=2, padding=1, bias=False))) 502 | else: 503 | main.add_module('Start-Conv2d', torch.nn.Conv2d(param.n_colors, param.D_h_size, kernel_size=4, stride=2, padding=1, bias=False)) 504 | if param.SELU: 505 | main.add_module('Start-SELU', torch.nn.SELU(inplace=True)) 506 | else: 507 | if param.Tanh_GD: 508 | main.add_module('Start-Tanh', torch.nn.Tanh()) 509 | else: 510 | main.add_module('Start-LeakyReLU', Activation()) 511 | image_size_new = param.image_size // 2 512 | # Size = D_h_size x image_size/2 x image_size/2 513 | 514 | ### Middle block (Done until we reach ? x 4 x 4) 515 | mult = 1 516 | ii = 0 517 | while image_size_new > 4: 518 | if param.spectral: 519 | main.add_module('Middle-SpectralConv2d [%d]' % ii, torch.nn.utils.spectral_norm(torch.nn.Conv2d(param.D_h_size * mult, param.D_h_size * (2*mult), kernel_size=4, stride=2, padding=1, bias=False))) 520 | else: 521 | main.add_module('Middle-Conv2d [%d]' % ii, torch.nn.Conv2d(param.D_h_size * mult, param.D_h_size * (2*mult), kernel_size=4, stride=2, padding=1, bias=False)) 522 | if param.SELU: 523 | main.add_module('Middle-SELU [%d]' % ii, torch.nn.SELU(inplace=True)) 524 | else: 525 | if not param.no_batch_norm_D and not param.spectral: 526 | main.add_module('Middle-BatchNorm2d [%d]' % ii, torch.nn.BatchNorm2d(param.D_h_size * (2*mult))) 527 | if param.Tanh_GD: 528 | main.add_module('Start-Tanh [%d]' % ii, torch.nn.Tanh()) 529 | else: 530 | main.add_module('Middle-LeakyReLU [%d]' % ii, Activation()) 531 | # Size = (D_h_size*(2*i)) x image_size/(2*i) x image_size/(2*i) 532 | image_size_new = image_size_new // 2 533 | mult *= 2 534 | ii += 1 535 | 536 | ### End block 537 | # Size = (D_h_size * mult) x 4 x 4 538 | if param.spectral: 539 | main.add_module('End-SpectralConv2d', torch.nn.utils.spectral_norm(torch.nn.Conv2d(param.D_h_size * mult, 1, kernel_size=4, stride=1, 
padding=0, bias=False))) 540 | else: 541 | main.add_module('End-Conv2d', torch.nn.Conv2d(param.D_h_size * mult, 1, kernel_size=4, stride=1, padding=0, bias=False)) 542 | # Size = 1 x 1 x 1 (Is a real cat or not?) 543 | self.main = main 544 | 545 | def forward(self, input): 546 | if isinstance(input.data, torch.cuda.FloatTensor) and param.n_gpu > 1: 547 | output = torch.nn.parallel.data_parallel(self.main, input, range(param.n_gpu)) 548 | else: 549 | output = self.main(input) 550 | # Convert from 1 x 1 x 1 to 1 so that we can compare to given label (cat or not?) 551 | return output.view(-1) 552 | 553 | if param.arch == 2: 554 | # Taken directly from https://github.com/ozanciga/gans-with-pytorch/blob/master/wgan-gp/models.py 555 | 556 | class MeanPoolConv(torch.nn.Module): 557 | def __init__(self, n_input, n_output, k_size): 558 | super(MeanPoolConv, self).__init__() 559 | conv1 = torch.nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 560 | self.model = torch.nn.Sequential(conv1) 561 | def forward(self, x): 562 | out = (x[:,:,::2,::2] + x[:,:,1::2,::2] + x[:,:,::2,1::2] + x[:,:,1::2,1::2]) / 4.0 563 | out = self.model(out) 564 | return out 565 | 566 | class ConvMeanPool(torch.nn.Module): 567 | def __init__(self, n_input, n_output, k_size): 568 | super(ConvMeanPool, self).__init__() 569 | conv1 = torch.nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 570 | self.model = torch.nn.Sequential(conv1) 571 | def forward(self, x): 572 | out = self.model(x) 573 | out = (out[:,:,::2,::2] + out[:,:,1::2,::2] + out[:,:,::2,1::2] + out[:,:,1::2,1::2]) / 4.0 574 | return out 575 | 576 | class UpsampleConv(torch.nn.Module): 577 | def __init__(self, n_input, n_output, k_size): 578 | super(UpsampleConv, self).__init__() 579 | 580 | self.model = torch.nn.Sequential( 581 | torch.nn.PixelShuffle(2), 582 | torch.nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 583 | ) 584 | def forward(self, 
x): 585 | x = x.repeat((1, 4, 1, 1)) # Weird concat of WGAN-GPs upsampling process. 586 | out = self.model(x) 587 | return out 588 | 589 | class ResidualBlock(torch.nn.Module): 590 | def __init__(self, n_input, n_output, k_size, resample='up', bn=True, spatial_dim=None): 591 | super(ResidualBlock, self).__init__() 592 | 593 | self.resample = resample 594 | 595 | if resample == 'up': 596 | self.conv1 = UpsampleConv(n_input, n_output, k_size) 597 | self.conv2 = torch.nn.Conv2d(n_output, n_output, k_size, padding=(k_size-1)//2) 598 | self.conv_shortcut = UpsampleConv(n_input, n_output, k_size) 599 | self.out_dim = n_output 600 | elif resample == 'down': 601 | self.conv1 = torch.nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 602 | self.conv2 = ConvMeanPool(n_input, n_output, k_size) 603 | self.conv_shortcut = ConvMeanPool(n_input, n_output, k_size) 604 | self.out_dim = n_output 605 | self.ln_dims = [n_input, spatial_dim, spatial_dim] # Define the dimensions for layer normalization. 
606 | else: 607 | self.conv1 = torch.nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 608 | self.conv2 = torch.nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 609 | self.conv_shortcut = None # Identity 610 | self.out_dim = n_input 611 | self.ln_dims = [n_input, spatial_dim, spatial_dim] 612 | 613 | self.model = torch.nn.Sequential( 614 | torch.nn.BatchNorm2d(n_input) if bn else torch.nn.LayerNorm(self.ln_dims), 615 | torch.nn.ReLU(inplace=True), 616 | self.conv1, 617 | torch.nn.BatchNorm2d(self.out_dim) if bn else torch.nn.LayerNorm(self.ln_dims), 618 | torch.nn.ReLU(inplace=True), 619 | self.conv2, 620 | ) 621 | 622 | def forward(self, x): 623 | if self.conv_shortcut is None: 624 | return x + self.model(x) 625 | else: 626 | return self.conv_shortcut(x) + self.model(x) 627 | 628 | class DiscBlock1(torch.nn.Module): 629 | def __init__(self, n_output): 630 | super(DiscBlock1, self).__init__() 631 | 632 | self.conv1 = torch.nn.Conv2d(3, n_output, 3, padding=(3-1)//2) 633 | self.conv2 = ConvMeanPool(n_output, n_output, 1) 634 | self.conv_shortcut = MeanPoolConv(3, n_output, 1) 635 | 636 | self.model = torch.nn.Sequential( 637 | self.conv1, 638 | torch.nn.ReLU(inplace=True), 639 | self.conv2 640 | ) 641 | 642 | def forward(self, x): 643 | return self.conv_shortcut(x) + self.model(x) 644 | 645 | class DCGAN_G(torch.nn.Module): 646 | def __init__(self): 647 | super(DCGAN_G, self).__init__() 648 | 649 | self.model = torch.nn.Sequential( # 128 x 1 x 1 650 | torch.nn.ConvTranspose2d(128, 128, 4, 1, 0), # 128 x 4 x 4 651 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 8 x 8 652 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 16 x 16 653 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 32 x 32 654 | torch.nn.BatchNorm2d(128), 655 | torch.nn.ReLU(inplace=True), 656 | torch.nn.Conv2d(128, 3, 3, padding=(3-1)//2), # 3 x 32 x 32 657 | torch.nn.Tanh() 658 | ) 659 | 660 | def forward(self, z): 661 | img = self.model(z) 662 | return img 663 | 664 | 
class DCGAN_D(torch.nn.Module): 665 | def __init__(self): 666 | super(DCGAN_D, self).__init__() 667 | n_output = 128 668 | ''' 669 | This is a parameter but since we experiment with a single size 670 | of 3 x 32 x 32 images, it is hardcoded here. 671 | ''' 672 | 673 | self.DiscBlock1 = DiscBlock1(n_output) # 128 x 16 x 16 674 | 675 | self.model = torch.nn.Sequential( 676 | ResidualBlock(n_output, n_output, 3, resample='down', bn=False, spatial_dim=16), # 128 x 8 x 8 677 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8 678 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8 679 | torch.nn.ReLU(inplace=True), 680 | ) 681 | self.l1 = torch.nn.Sequential(torch.nn.Linear(128, 1)) # 128 x 1 682 | 683 | def forward(self, x): 684 | # x = x.view(-1, 3, 32, 32) 685 | y = self.DiscBlock1(x) 686 | y = self.model(y) 687 | y = y.view(x.size(0), 128, -1) 688 | y = y.mean(dim=2) 689 | out = self.l1(y).unsqueeze_(1).unsqueeze_(2) # or *.view(x.size(0), 128, 1, 1, 1) 690 | return out.view(-1) 691 | 692 | ## Initialization 693 | G = DCGAN_G() 694 | D = DCGAN_D() 695 | 696 | # Initialize weights 697 | def weights_init(m): 698 | classname = m.__class__.__name__ 699 | if classname.find('Conv') != -1: 700 | m.weight.data.normal_(0.0, 0.02) 701 | elif classname.find('BatchNorm') != -1: 702 | # Estimated variance, must be around 1 703 | m.weight.data.normal_(1.0, 0.02) 704 | # Estimated mean, must be around 0 705 | m.bias.data.fill_(0) 706 | if param.arch < 2: 707 | G.apply(weights_init) 708 | D.apply(weights_init) 709 | print("Initialized weights") 710 | print("Initialized weights", file=log_output) 711 | 712 | # Criterion 713 | BCE_stable = torch.nn.BCEWithLogitsLoss() 714 | 715 | # Soon to be variables 716 | x = torch.FloatTensor(param.batch_size, param.n_colors, param.image_size, param.image_size) 717 | x_fake = torch.FloatTensor(param.batch_size, param.n_colors, param.image_size, param.image_size) 
718 | y = torch.FloatTensor(param.batch_size) 719 | y2 = torch.FloatTensor(param.batch_size) 720 | # Weighted sum of fake and real image, for gradient penalty 721 | x_both = torch.FloatTensor(param.batch_size, param.n_colors, param.image_size, param.image_size) 722 | z = torch.FloatTensor(param.batch_size, param.z_size, 1, 1) 723 | # Uniform weight 724 | u = torch.FloatTensor(param.batch_size, 1, 1, 1) 725 | # This is to see during training, size and values won't change 726 | z_test = torch.FloatTensor(param.batch_size, param.z_size, 1, 1).normal_(0, 1) 727 | # For the gradients, we need to specify which one we want and want them all 728 | grad_outputs = torch.ones(param.batch_size) 729 | # For when calculating the approximate bias with RaGANs and RcGANs (contains log(2), nothing more) 730 | log_2 = torch.FloatTensor(1) 731 | w_grad = torch.FloatTensor([param.penalty]) # lagrange multipliers if using augmented lagrangian (initialized at given penalty value) 732 | 733 | # Everything cuda 734 | if param.cuda: 735 | G = G.cuda() 736 | D = D.cuda() 737 | BCE_stable.cuda() 738 | x = x.cuda() 739 | x_fake = x_fake.cuda() 740 | x_both = x_both.cuda() 741 | w_grad = w_grad.cuda() 742 | y = y.cuda() 743 | y2 = y2.cuda() 744 | u = u.cuda() 745 | z = z.cuda() 746 | z_test = z_test.cuda() 747 | grad_outputs = grad_outputs.cuda() 748 | log_2 = log_2.cuda() 749 | 750 | # Now Variables 751 | x = Variable(x) 752 | x_fake = Variable(x_fake) 753 | y = Variable(y) 754 | y2 = Variable(y2) 755 | z = Variable(z) 756 | z_test = Variable(z_test) 757 | w_grad = Variable(w_grad, requires_grad=True) 758 | 759 | log_2.fill_(2) 760 | log_2 = torch.log(log_2) 761 | 762 | # Based on DCGAN paper, they found using betas[0]=.50 better. 
763 | # betas[0] is the weight given to the previous mean of the gradient 764 | # betas[1] is the weight given to the previous uncentered variance (mean of the squared gradient) 765 | optimizerD = torch.optim.Adam(D.parameters(), lr=param.lr_D, betas=(param.beta1, param.beta2), weight_decay=param.weight_decay) 766 | optimizerG = torch.optim.Adam(G.parameters(), lr=param.lr_G, betas=(param.beta1, param.beta2), weight_decay=param.weight_decay) 767 | 768 | # Exponential decay of the learning rate: lr is multiplied by (1 - decay) at every scheduler step 769 | decayD = torch.optim.lr_scheduler.ExponentialLR(optimizerD, gamma=1-param.decay) 770 | decayG = torch.optim.lr_scheduler.ExponentialLR(optimizerG, gamma=1-param.decay) 771 | 772 | # Load existing models 773 | if param.load is not None: 774 | checkpoint = torch.load(param.load) 775 | current_set_images = checkpoint['current_set_images'] 776 | iter_offset = checkpoint['i'] 777 | G.load_state_dict(checkpoint['G_state']) 778 | D.load_state_dict(checkpoint['D_state']) 779 | optimizerG.load_state_dict(checkpoint['G_optimizer']) 780 | optimizerD.load_state_dict(checkpoint['D_optimizer']) 781 | decayG.load_state_dict(checkpoint['G_scheduler']) 782 | decayD.load_state_dict(checkpoint['D_scheduler']) 783 | z_test.copy_(checkpoint['z_test']) 784 | del checkpoint 785 | print(f'Resumed from iteration {current_set_images*param.gen_every}.') 786 | else: 787 | current_set_images = 0 788 | iter_offset = 0 789 | 790 | print(G) 791 | print(G, file=log_output) 792 | print(D) 793 | print(D, file=log_output) 794 | 795 | ## Fitting model 796 | for i in range(iter_offset, param.n_iter): 797 | 798 | # Fake images saved 799 | if i % param.print_every == 0: 800 | fake_test = G(z_test) 801 | vutils.save_image(fake_test.data, '%s/images/fake_samples_iter%05d.png' % (base_dir, i), normalize=True) 802 | 803 | for p in D.parameters(): 804 | p.requires_grad = True 805 | 806 | for t in range(param.Diters): 807 | 808 | ######################## 809 | # (1) Update D network # 810 | 
######################## 811 | 812 | D.zero_grad() 813 | images = random_sample.__next__() 814 | # Mostly necessary for the last batch, because N might not be a multiple of batch_size 815 | current_batch_size = images.size(0) 816 | if param.cuda: 817 | images = images.cuda() 818 | # Transfer batch of images to x 819 | x.resize_as_(images).copy_(images) 820 | del images 821 | y_pred = D(x) 822 | 823 | if param.show_graph and i == 0: 824 | # Visualization of the autograd graph 825 | d = pv.make_dot(y_pred, D.state_dict()) 826 | d.view() 827 | 828 | if param.loss_D in [1,2,3,4]: 829 | # Train with real data 830 | y.resize_(current_batch_size).fill_(1) 831 | if param.loss_D == 1: 832 | errD_real = BCE_stable(y_pred, y) 833 | if param.loss_D == 2: 834 | errD_real = torch.mean((y_pred - y) ** 2) 835 | #a = torch.abs(y_pred - y) 836 | #errD_real = torch.mean(a**(1+torch.log(1+a**4))) 837 | if param.loss_D == 4: 838 | errD_real = -torch.mean(y_pred) 839 | if param.loss_D == 3: 840 | errD_real = torch.mean(torch.nn.ReLU()(1.0 - y_pred)) 841 | 842 | # Train with fake data 843 | z.resize_(current_batch_size, param.z_size, 1, 1).normal_(0, 1) 844 | fake = G(z) 845 | x_fake.resize_(fake.data.size()).copy_(fake.data) 846 | y2.resize_(current_batch_size).fill_(0) 847 | # Detach the fake images from G's graph before passing them to D 848 | y_pred_fake = D(x_fake.detach()) 849 | if param.loss_D == 1: 850 | errD_fake = BCE_stable(y_pred_fake, y2) 851 | if param.loss_D == 2: 852 | errD_fake = torch.mean((y_pred_fake) ** 2) 853 | #a = torch.abs(y_pred_fake - y) 854 | #errD_fake = torch.mean(a**(1+torch.log(1+a**2))) 855 | if param.loss_D == 4: 856 | errD_fake = torch.mean(y_pred_fake) 857 | if param.loss_D == 3: 858 | errD_fake = torch.mean(torch.nn.ReLU()(1.0 + y_pred_fake)) 859 | errD = errD_real + errD_fake 860 | #print(errD) 861 | else: 862 | y.resize_(current_batch_size).fill_(1) 863 | y2.resize_(current_batch_size).fill_(0) 864 | z.resize_(current_batch_size, param.z_size, 1, 
1).normal_(0, 1) 865 | fake = G(z) 866 | x_fake.resize_(fake.data.size()).copy_(fake.data) 867 | y_pred_fake = D(x_fake.detach()) 868 | 869 | # Relativistic average GANs 870 | if param.loss_D == 11: 871 | errD = BCE_stable(y_pred - torch.mean(y_pred_fake), y) + BCE_stable(y_pred_fake - torch.mean(y_pred), y2) 872 | if param.loss_D == 12: 873 | if param.no_bias: 874 | errD = torch.mean((y_pred - torch.mean(y_pred_fake) - y) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) + y) ** 2) - (torch.var(y_pred, dim=0)+torch.var(y_pred_fake, dim=0))/param.batch_size 875 | else: 876 | errD = torch.mean((y_pred - torch.mean(y_pred_fake) - y) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) + y) ** 2) 877 | if param.loss_D == 13: 878 | errD = torch.mean(torch.nn.ReLU()(1.0 - (y_pred - torch.mean(y_pred_fake)))) + torch.mean(torch.nn.ReLU()(1.0 + (y_pred_fake - torch.mean(y_pred)))) 879 | 880 | # Relativistic centered GANs 881 | if param.loss_D == 21: 882 | full_mean = (torch.mean(y_pred) + torch.mean(y_pred_fake))/2 883 | errD = BCE_stable(y_pred - full_mean, y) + BCE_stable(y_pred_fake - full_mean, y2) 884 | if param.loss_D == 22: 885 | full_mean = (torch.mean(y_pred) + torch.mean(y_pred_fake))/2 886 | if param.no_bias: 887 | errD = torch.mean((y_pred - full_mean - y) ** 2) + torch.mean((y_pred_fake - full_mean + y) ** 2) + (torch.var(y_pred, dim=0)+torch.var(y_pred_fake, dim=0))/(2*param.batch_size) 888 | else: 889 | errD = torch.mean((y_pred - full_mean - y) ** 2) + torch.mean((y_pred_fake - full_mean + y) ** 2) 890 | if param.loss_D == 23: 891 | full_mean = (torch.mean(y_pred) + torch.mean(y_pred_fake))/2 892 | errD = torch.mean(torch.nn.ReLU()(1.0 - (y_pred - full_mean))) + torch.mean(torch.nn.ReLU()(1.0 + (y_pred_fake - full_mean))) 893 | 894 | # Relativistic paired GANs (Without the MVUE) 895 | if param.loss_D == 31: 896 | errD = 2*BCE_stable(y_pred - y_pred_fake, y) 897 | if param.loss_D == 32: 898 | errD = 2*torch.mean((y_pred - y_pred_fake - y) ** 2) 899 | if 
param.loss_D == 33: 900 | errD = 2*torch.mean(torch.nn.ReLU()(1.0 - (y_pred - y_pred_fake))) 901 | 902 | if param.loss_D in [41,42,43]: 903 | 904 | # Relativistic paired GANs (MVUE, slower) 905 | # Creating the Cartesian-product subtraction, which is sadly very demanding: O(k^2), where k is the batch size 906 | grid_x, grid_y = torch.meshgrid([y_pred, y_pred_fake]) 907 | y_pred_subst = (grid_x - grid_y) 908 | y.resize_(current_batch_size,current_batch_size).fill_(1) 909 | 910 | if param.loss_D == 41: 911 | errD = 2*BCE_stable(y_pred_subst, y) 912 | if param.loss_D == 42: 913 | errD = 2*torch.mean((y_pred_subst - y) ** 2) 914 | if param.loss_D == 43: 915 | errD = 2*torch.mean(torch.nn.ReLU()(1.0 - y_pred_subst)) 916 | 917 | errD_real = errD 918 | errD_fake = errD 919 | 920 | errD.backward(retain_graph=True) 921 | 922 | if (param.loss_D in [4] or param.grad_penalty) and (not param.no_grad_penalty): 923 | # Gradient penalty 924 | u.resize_(current_batch_size, 1, 1, 1) 925 | u.uniform_(0, 1) 926 | if param.real_only: 927 | x_both = x.data 928 | elif param.fake_only: 929 | x_both = x_fake.data 930 | else: 931 | x_both = x.data*u + x_fake.data*(1-u) 932 | if param.cuda: 933 | x_both = x_both.cuda() 934 | 935 | # We only want the gradients with respect to x_both 936 | x_both = Variable(x_both, requires_grad=True) 937 | y0 = D(x_both) 938 | grad = torch.autograd.grad(outputs=y0, 939 | inputs=x_both, grad_outputs=grad_outputs, 940 | retain_graph=True, create_graph=True, 941 | only_inputs=True)[0] 942 | x_both.requires_grad_(False) 943 | sh = grad.shape 944 | grad = grad.view(current_batch_size,-1) 945 | 946 | if param.l1_margin_no_abs: 947 | grad_abs = torch.abs(grad) 948 | else: 949 | grad_abs = grad 950 | 951 | if param.l1_margin: 952 | grad_norm , _ = torch.max(grad_abs,1) 953 | elif param.l1_margin_smoothmax: 954 | grad_norm = torch.sum(grad_abs*torch.exp(param.smoothmax*grad_abs))/torch.sum(torch.exp(param.smoothmax*grad_abs)) 955 | elif param.l1_margin_logsumexp: 956 | grad_norm = 
torch.logsumexp(grad_abs,1) 957 | elif param.linf_margin: 958 | grad_norm = grad.norm(1,1) 959 | else: 960 | grad_norm = grad.norm(2,1) 961 | 962 | if param.penalty_type == 'squared-diff': 963 | constraint = (grad_norm-1).pow(2) 964 | elif param.penalty_type == 'clamp': 965 | constraint = grad_norm.clamp(min=1.) - 1. 966 | elif param.penalty_type == 'squared-clamp': 967 | constraint = (grad_norm.clamp(min=1.) - 1.).pow(2) 968 | elif param.penalty_type == 'squared': 969 | constraint = grad_norm.pow(2) 970 | elif param.penalty_type == 'TV': 971 | constraint = grad_norm 972 | elif param.penalty_type == 'abs': 973 | constraint = torch.abs(grad_norm-1) 974 | elif param.penalty_type == 'hinge': 975 | constraint = torch.nn.ReLU()(grad_norm - 1) 976 | elif param.penalty_type == 'hinge2': 977 | constraint = (torch.nn.ReLU()(grad_norm - 1)).pow(2) 978 | else: 979 | raise ValueError('penalty type %s is not valid'%param.penalty_type) 980 | 981 | if param.reduction == 'mean': 982 | constraint = constraint.mean() 983 | elif param.reduction == 'max': 984 | constraint = constraint.max() 985 | elif param.reduction == 'softmax': 986 | sm = constraint.softmax(0) 987 | constraint = (sm*constraint).sum() 988 | else: 989 | raise ValueError('reduction type %s is not valid'%param.reduction) 990 | if param.print_grad: 991 | print(constraint) 992 | print(constraint, file=log_output) 993 | 994 | if param.grad_penalty_aug: 995 | grad_penalty = (-w_grad*constraint + (param.rho/2)*(constraint)**2) 996 | grad_penalty.backward(retain_graph=True) 997 | else: 998 | grad_penalty = param.penalty*constraint 999 | grad_penalty.backward(retain_graph=True) 1000 | else: 1001 | grad_penalty = 0. 
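The gradient-penalty block above works in three steps: it picks a norm of each sample's gradient, turns that norm into a constraint according to `penalty_type`, and reduces the constraint over the batch according to `reduction`. The following is a minimal plain-Python sketch of those steps on toy numbers, without PyTorch. The function names `grad_norm`, `penalty` and `reduce_constraint` are illustrative and not taken from the script. The max-entry branch here always takes absolute values, which the script only does when `l1_margin_no_abs` is set.

```python
import math


def grad_norm(grad, kind):
    """Norm of one sample's flattened gradient, given as a list of floats."""
    if kind == "linf":  # max-entry branch (selected by --l1_margin in the script)
        return max(abs(g) for g in grad)
    if kind == "l1":    # grad.norm(1, 1) branch (selected by --linf_margin)
        return sum(abs(g) for g in grad)
    if kind == "l2":    # fallback branch, grad.norm(2, 1)
        return math.sqrt(sum(g * g for g in grad))
    raise ValueError("norm %s is not valid" % kind)


def penalty(norm, penalty_type):
    """Constraint value for a single gradient norm."""
    excess = max(norm - 1.0, 0.0)  # same value as clamp(min=1) - 1 and ReLU(norm - 1)
    if penalty_type == "squared-diff":
        return (norm - 1.0) ** 2
    if penalty_type in ("clamp", "hinge"):
        return excess
    if penalty_type in ("squared-clamp", "hinge2"):
        return excess ** 2
    if penalty_type == "squared":
        return norm ** 2
    if penalty_type == "TV":
        return norm
    if penalty_type == "abs":
        return abs(norm - 1.0)
    raise ValueError("penalty type %s is not valid" % penalty_type)


def reduce_constraint(values, reduction):
    """Collapse the per-sample constraints into one number."""
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "max":
        return max(values)
    if reduction == "softmax":
        m = max(values)  # subtract the max for numerical stability
        weights = [math.exp(v - m) for v in values]
        total = sum(weights)
        return sum(w / total * v for w, v in zip(weights, values))
    raise ValueError("reduction type %s is not valid" % reduction)


grads = [[3.0, 4.0], [0.0, 0.5]]                         # two toy per-sample gradients
norms = [grad_norm(g, "l2") for g in grads]              # [5.0, 0.5]
two_sided = [penalty(n, "squared-diff") for n in norms]  # [16.0, 0.25]
one_sided = [penalty(n, "hinge") for n in norms]         # [4.0, 0.0]
print(reduce_constraint(two_sided, "mean"), reduce_constraint(one_sided, "max"))
```

In this sketch `'clamp'` and `'hinge'` coincide, as do `'squared-clamp'` and `'hinge2'`, because max(n, 1) - 1 equals max(n - 1, 0). The one-sided forms leave samples with a gradient norm below 1 unpenalized, while `'squared-diff'` pulls them back toward 1.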
1002 | 1003 | optimizerD.step() 1004 | # Augmented Lagrangian 1005 | if param.grad_penalty_aug: 1006 | w_grad.data += param.rho * w_grad.grad.data 1007 | w_grad.grad.zero_() 1008 | 1009 | 1010 | ######################## 1011 | # (2) Update G network # 1012 | ######################## 1013 | 1014 | # Freeze D while updating G (makes it a tiny bit faster) 1015 | for p in D.parameters(): 1016 | p.requires_grad = False 1017 | 1018 | for t in range(param.Giters): 1019 | 1020 | G.zero_grad() 1021 | if param.resample == 1: 1022 | y.resize_(current_batch_size).fill_(1) 1023 | z.resize_(current_batch_size, param.z_size, 1, 1).normal_(0, 1) 1024 | fake = G(z) 1025 | 1026 | if param.loss_D not in [1, 2, 3, 4]: 1027 | images = random_sample.__next__() 1028 | current_batch_size = images.size(0) 1029 | if param.cuda: 1030 | images = images.cuda() 1031 | x.resize_as_(images).copy_(images) 1032 | del images 1033 | 1034 | y_pred_fake = D(fake) 1035 | y2.resize_(current_batch_size).fill_(0) 1036 | if param.loss_D == 1: 1037 | errG = BCE_stable(y_pred_fake, y) 1038 | if param.loss_D == 2: 1039 | errG = torch.mean((y_pred_fake - y) ** 2) 1040 | if param.loss_D == 4: 1041 | errG = -torch.mean(y_pred_fake) 1042 | if param.loss_D == 3: 1043 | errG = -torch.mean(y_pred_fake) 1044 | 1045 | # Relativistic average GANs 1046 | if param.loss_D == 11: 1047 | y_pred = D(x) 1048 | errG = BCE_stable(y_pred - torch.mean(y_pred_fake), y2) + BCE_stable(y_pred_fake - torch.mean(y_pred), y) 1049 | if param.loss_D == 12: 1050 | y_pred = D(x) 1051 | if param.no_bias: 1052 | errG = torch.mean((y_pred - torch.mean(y_pred_fake) + y) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) - y) ** 2) - (torch.var(y_pred_fake, dim=0)/param.batch_size) 1053 | else: 1054 | errG = torch.mean((y_pred - torch.mean(y_pred_fake) + y) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) - y) ** 2) 1055 | if param.loss_D == 13: 1056 | y_pred = D(x) 1057 | errG = torch.mean(torch.nn.ReLU()(1.0 + (y_pred - torch.mean(y_pred_fake)))) + 
torch.mean(torch.nn.ReLU()(1.0 - (y_pred_fake - torch.mean(y_pred)))) 1058 | 1059 | if param.loss_D == 21: 1060 | y_pred = D(x) 1061 | full_mean = (torch.mean(y_pred) + torch.mean(y_pred_fake))/2 1062 | errG = BCE_stable(y_pred - full_mean, y2) + BCE_stable(y_pred_fake - full_mean, y) 1063 | if param.loss_D == 22: # (y_hat-1)^2 + (y_hat+1)^2 1064 | y_pred = D(x) 1065 | full_mean = (torch.mean(y_pred) + torch.mean(y_pred_fake))/2 1066 | if param.no_bias: 1067 | errG = torch.mean((y_pred - full_mean + y) ** 2) + torch.mean((y_pred_fake - full_mean - y) ** 2) + (torch.var(y_pred_fake, dim=0)/(2*param.batch_size)) 1068 | else: 1069 | errG = torch.mean((y_pred - full_mean + y) ** 2) + torch.mean((y_pred_fake - full_mean - y) ** 2) 1070 | if param.loss_D == 23: 1071 | y_pred = D(x) 1072 | full_mean = (torch.mean(y_pred) + torch.mean(y_pred_fake))/2 1073 | errG = torch.mean(torch.nn.ReLU()(1.0 + (y_pred - full_mean))) + torch.mean(torch.nn.ReLU()(1.0 - (y_pred_fake - full_mean))) 1074 | 1075 | # Relativistic paired GANs (Without the MVUE) 1076 | if param.loss_D == 31: 1077 | y_pred = D(x) 1078 | errG = BCE_stable(y_pred_fake - y_pred, y) 1079 | if param.loss_D == 32: # (y_hat-1)^2 + (y_hat+1)^2 1080 | y_pred = D(x) 1081 | errG = torch.mean((y_pred_fake - y_pred - y) ** 2) 1082 | if param.loss_D == 33: 1083 | y_pred = D(x) 1084 | errG = torch.mean(torch.nn.ReLU()(1.0 - (y_pred_fake - y_pred))) 1085 | 1086 | if param.loss_D in [41,42,43]: 1087 | 1088 | # Relativistic paired GANs 1089 | # Creating the Cartesian-product subtraction, which is sadly very demanding: O(k^2), where k is the batch size 1090 | y_pred = D(x) 1091 | grid_x, grid_y = torch.meshgrid([y_pred, y_pred_fake]) 1092 | y_pred_subst = grid_y - grid_x 1093 | y.resize_(current_batch_size,current_batch_size).fill_(1) 1094 | 1095 | if param.loss_D == 41: 1096 | errG = 2*BCE_stable(y_pred_subst, y) 1097 | if param.loss_D == 42: # (y_hat-1)^2 + (y_hat+1)^2 1098 | errG = 2*torch.mean((y_pred_subst - y) ** 2) 1099 | if param.loss_D 
== 43: 1100 | errG = 2*torch.mean(torch.nn.ReLU()(1.0 - y_pred_subst)) 1101 | 1102 | errG.backward() 1103 | D_G = y_pred_fake.data.mean() 1104 | optimizerG.step() 1105 | decayD.step() 1106 | decayG.step() 1107 | 1108 | # Log results so we can see them in TensorBoard after 1109 | #log_value('Diff', -(errD.data.item()+errG.data.item()), i) 1110 | #log_value('errD', errD.data.item(), i) 1111 | #log_value('errG', errG.data.item(), i) 1112 | 1113 | if (i+1) % param.print_every == 0: 1114 | end = time.time() 1115 | fmt = '[%d] Diff: %.4f loss_D: %.4f loss_G: %.4f time:%.4f' 1116 | s = fmt % (i, -errD.data.item()+errG.data.item(), errD.data.item(), errG.data.item(), end - start) 1117 | print(s) 1118 | print(s, file=log_output) 1119 | 1120 | # Evaluation metrics 1121 | if (i+1) % param.gen_every == 0: 1122 | 1123 | current_set_images += 1 1124 | 1125 | # Save models 1126 | if param.save: 1127 | if not os.path.exists('%s/models/' % (param.logs_folder)): 1128 | os.makedirs('%s/models/' % (param.logs_folder)) 1129 | torch.save({ 1130 | 'i': i + 1, 1131 | 'current_set_images': current_set_images, 1132 | 'G_state': G.state_dict(), 1133 | 'D_state': D.state_dict(), 1134 | 'G_optimizer': optimizerG.state_dict(), 1135 | 'D_optimizer': optimizerD.state_dict(), 1136 | 'G_scheduler': decayG.state_dict(), 1137 | 'D_scheduler': decayD.state_dict(), 1138 | 'z_test': z_test, 1139 | }, '%s/models/state_%02d.pth' % (param.logs_folder, current_set_images)) 1140 | s = 'Models saved' 1141 | print(s) 1142 | print(s, file=log_output) 1143 | 1144 | # Delete previously existing images 1145 | if os.path.exists('%s/%01d/' % (param.extra_folder, current_set_images)): 1146 | for root, dirs, files in os.walk('%s/%01d/' % (param.extra_folder, current_set_images)): 1147 | for f in files: 1148 | os.unlink(os.path.join(root, f)) 1149 | else: 1150 | os.makedirs('%s/%01d/' % (param.extra_folder, current_set_images)) 1151 | 1152 | # Generate 50k images for FID/Inception to be calculated later (not on this 
script, since running both tensorflow and pytorch at the same time causes issues) 1153 | ext_curr = 0 1154 | z_extra = torch.FloatTensor(100, param.z_size, 1, 1) 1155 | if param.cuda: 1156 | z_extra = z_extra.cuda() 1157 | for ext in range(int(param.gen_extra_images/100)): 1158 | fake_test = G(Variable(z_extra.normal_(0, 1))) 1159 | for ext_i in range(100): 1160 | vutils.save_image((fake_test[ext_i].data*.50)+.50, '%s/%01d/fake_samples_%05d.png' % (param.extra_folder, current_set_images,ext_curr), normalize=False, padding=0) 1161 | ext_curr += 1 1162 | del z_extra 1163 | del fake_test 1164 | # Later use this command to get FID of first set: 1165 | # python fid.py "/home/alexia/Output/Extra/01" "/home/alexia/Datasets/fid_stats_cifar10_train.npz" -i "/home/alexia/Inception" --gpu "0" 1166 | end = time.time() 1167 | fmt = 'Total time: [%.4f]' % (end - start) 1168 | print(fmt) 1169 | print(fmt, file=log_output) -------------------------------------------------------------------------------- /Code/create_FID_stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import glob 5 | import numpy as np 6 | import fid_new as fid 7 | # from scipy.misc import imread (unused in this file; scipy.misc.imread was removed in SciPy 1.2) 8 | import tensorflow as tf 9 | 10 | ######## 11 | # PATHS 12 | ######## 13 | 14 | import argparse 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data_path', default='/Datasets/CIFAR10_images/train') 17 | parser.add_argument('--inception_path', default='/models/Inception') 18 | parser.add_argument('--output_path', default='/CIFAR_fid_stats.npz') 19 | parser.add_argument('--batch_size', type=int, default=100) 20 | param = parser.parse_args() 21 | 22 | # Prefix the paths with SLURM_TMPDIR so that we work on the local drive 23 | param.data_path = os.environ["SLURM_TMPDIR"] + param.data_path 24 | param.output_path = os.environ["SLURM_TMPDIR"] + param.output_path 25 | param.inception_path = os.environ["SLURM_TMPDIR"] + param.inception_path 
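The three assignments above concatenate plain strings, so a path only ends up inside `SLURM_TMPDIR` if it starts with a slash. The following is a minimal sketch of a more forgiving helper. The name `under_tmpdir` is hypothetical and not part of the repository, and the sketch assumes POSIX-style paths, as on a SLURM cluster.

```python
import os
import posixpath


def under_tmpdir(path, tmpdir=None):
    # Place `path` below the job-local directory, whether or not it starts with '/'
    if tmpdir is None:
        tmpdir = os.environ.get("SLURM_TMPDIR", "")
    # Strip leading slashes so that join() does not discard tmpdir
    return posixpath.join(tmpdir, path.lstrip("/"))


print(under_tmpdir("/Datasets/CIFAR10_images/train", "/tmp/job"))
print(under_tmpdir("CIFAR_fid_stats.npz", "/tmp/job"))
```

Both calls return a path below `/tmp/job`, whereas plain concatenation of `/tmp/job` and `CIFAR_fid_stats.npz` would give `/tmp/jobCIFAR_fid_stats.npz`.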
26 | 27 | data_path = param.data_path # set path to training set images 28 | output_path = param.output_path # path where the statistics are stored 29 | # if you have downloaded and extracted 30 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 31 | # set this path to the directory where the extracted files are, otherwise 32 | # just set it to None and the script will later download the files for you 33 | inception_path = param.inception_path 34 | print(inception_path) 35 | print("check for inception model..", end=" ", flush=True) 36 | inception_path = fid.check_or_download_inception(inception_path) # download inception if necessary 37 | print("ok") 38 | 39 | # collect the image file paths; the pixel data is not loaded here (the in-memory loader below is commented out) 40 | print("load images..", end=" ", flush=True) 41 | image_list = glob.glob(os.path.join(data_path, '*.jpg')) 42 | image_list.extend(glob.glob(os.path.join(data_path, '*.png'))) 43 | #images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list]) 44 | #print("%d images found and loaded" % len(images)) 45 | print("%d images found" % len(image_list)) 46 | 47 | print("create inception graph..", end=" ", flush=True) 48 | fid.create_inception_graph(inception_path) # load the graph into the current TF graph 49 | print("ok") 50 | 51 | print("calculate FID stats..", end=" ", flush=True) 52 | with tf.Session() as sess: 53 | sess.run(tf.global_variables_initializer()) 54 | mu, sigma = fid.calculate_activation_statistics_from_files(image_list, sess, batch_size=param.batch_size) 55 | np.savez_compressed(output_path, mu=mu, sigma=sigma) 56 | print("finished") -------------------------------------------------------------------------------- /Code/experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /home/jolicoea/my_projects/MaxMargin/Code 4 | 5 | ### Data export 6 | bash startup_tmp.sh dir1="CIFAR10" dir2="Meow_64x64" 
dir3="Meow_256x256" 7 | 8 | 9 | 10 | ## CIFAR-10 (.50,.99 adam) 11 | 12 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin 13 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 linf squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 14 | 15 | 16 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --l1_margin 17 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 18 | 19 | 20 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge' --l1_margin 21 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 linf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 22 | 23 | 24 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --penalty-type 'hinge' --l1_margin 25 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 26 | 27 | 28 | 29 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin_smoothmax --smoothmax 1 30 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 linf-smooth1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 31 
| 32 | 33 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --l1_margin_smoothmax --smoothmax 1 34 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf-smooth1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 35 | 36 | 37 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge' --l1_margin_smoothmax --smoothmax 1 38 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 linf-smooth1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 39 | 40 | 41 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --penalty-type 'hinge' --l1_margin_smoothmax --smoothmax 1 42 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf-smooth1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 43 | 44 | 45 | 46 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True 47 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 48 | 49 | 50 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True 51 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 52 | 53 | 54 | python GAN.py --loss_D 3 --image_size 
32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge' 55 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 56 | 57 | 58 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --penalty-type 'hinge' 59 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 60 | 61 | 62 | 63 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --linf_margin 64 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 65 | 66 | 67 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --linf_margin 68 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 69 | 70 | 71 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge' --linf_margin 72 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 73 | 74 | 75 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 
--print_every 5000 --gen_extra_images 50000 --CIFAR10 True --penalty-type 'hinge' --linf_margin 76 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 77 | 78 | 79 | 80 | 81 | 82 | 83 | ## CIFAR-10 (0,.9 adam) 84 | 85 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True --l1_margin 86 | bash fid_script.sh 10 "adam0/.50 HingeGAN CIFAR-10 lr .0002 linf squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 87 | 88 | 89 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --l1_margin 90 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 91 | 92 | 93 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True --penalty-type 'hinge' --l1_margin 94 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 linf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 95 | 96 | 97 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --penalty-type 'hinge' --l1_margin 98 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 99 | 100 | 101 | 102 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 
--batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True --l1_margin_smoothmax --smoothmax 1 103 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 linf-smooth1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 104 | 105 | 106 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --l1_margin_smoothmax --smoothmax 1 107 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf-smooth1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 108 | 109 | 110 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True --penalty-type 'hinge' --l1_margin_smoothmax --smoothmax 1 111 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 linf-smooth1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 112 | 113 | 114 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --penalty-type 'hinge' --l1_margin_smoothmax --smoothmax 1 115 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 linf-smooth1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 116 | 117 | 118 | 119 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True 120 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l2 squared-1 penalty" 
10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 121 | 122 | 123 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 124 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 125 | 126 | 127 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True --penalty-type 'hinge' 128 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 129 | 130 | 131 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --penalty-type 'hinge' 132 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 133 | 134 | 135 | 136 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True --linf_margin 137 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 138 | 139 | 140 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --linf_margin 141 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l1 squared-1 penalty" 10000 
"$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 142 | 143 | 144 | python GAN.py --loss_D 3 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --grad_penalty True --penalty-type 'hinge' --linf_margin 145 | bash fid_script.sh 10 "HingeGAN CIFAR-10 lr .0002 l1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 146 | 147 | 148 | python GAN.py --loss_D 4 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --beta1 0 --beta2 .90 --penalty-type 'hinge' --linf_margin 149 | bash fid_script.sh 10 "WGAN CIFAR-10 lr .0002 l1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 150 | 151 | 152 | 153 | 154 | 155 | ## Meow 64 156 | 157 | python GAN.py --loss_D 3 --input_folder '/Datasets/Meow_64x64' --image_size 64 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True 158 | bash fid_script.sh 10 "Hinge Meow-64 lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/CAT_fid_stats64.npz" 159 | 160 | 161 | python GAN.py --loss_D 3 --input_folder '/Datasets/Meow_64x64' --image_size 64 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --l1_margin --penalty-type 'hinge' 162 | bash fid_script.sh 10 "Hinge Meow-64 lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/CAT_fid_stats64.npz" 163 | 164 | 165 | python GAN.py --loss_D 3 --input_folder '/Datasets/Meow_64x64' --image_size 64 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 
--gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --penalty-type 'hinge' 166 | bash fid_script.sh 10 "HingeGAN Meow-64 lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/CAT_fid_stats64.npz" 167 | 168 | 169 | 170 | python GAN.py --loss_D 4 --input_folder '/Datasets/Meow_64x64' --image_size 64 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True 171 | bash fid_script.sh 10 "WGAN Meow-64 lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/CAT_fid_stats64.npz" 172 | 173 | 174 | python GAN.py --loss_D 4 --input_folder '/Datasets/Meow_64x64' --image_size 64 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --l1_margin --penalty-type 'hinge' 175 | bash fid_script.sh 10 "WGAN Meow-64 lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/CAT_fid_stats64.npz" 176 | 177 | 178 | python GAN.py --loss_D 4 --input_folder '/Datasets/Meow_64x64' --image_size 64 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --penalty-type 'hinge' 179 | bash fid_script.sh 10 "WGAN Meow-64 lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/CAT_fid_stats64.npz" 180 | 181 | 182 | 183 | 184 | 185 | ## CIFAR-10 resnet 186 | 187 | python GAN.py --loss_D 3 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 188 | bash fid_script.sh 10 "Hinge CIFAR10-resnet lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 189 | 190 | 191 | python 
GAN.py --loss_D 3 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 --l1_margin --penalty-type 'hinge' 192 | bash fid_script.sh 10 "Hinge CIFAR10-resnet lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 193 | 194 | 195 | python GAN.py --loss_D 3 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 --penalty-type 'hinge' 196 | bash fid_script.sh 10 "HingeGAN CIFAR10-resnet lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 197 | 198 | 199 | 200 | python GAN.py --loss_D 4 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 201 | bash fid_script.sh 10 "WGAN CIFAR10-resnet lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 202 | 203 | 204 | python GAN.py --loss_D 4 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 --l1_margin --penalty-type 'hinge' 205 | bash fid_script.sh 10 "WGAN CIFAR10-resnet lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 206 | 207 | 208 | python GAN.py --loss_D 4 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 
--penalty-type 'hinge' 209 | bash fid_script.sh 10 "WGAN CIFAR10-resnet lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 210 | 211 | 212 | 213 | python GAN.py --loss_D 3 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 214 | bash fid_script.sh 10 "Hinge CIFAR10-resnet lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 215 | 216 | 217 | python GAN.py --loss_D 3 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 --l1_margin --penalty-type 'hinge' 218 | bash fid_script.sh 10 "Hinge CIFAR10-resnet lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 219 | 220 | 221 | python GAN.py --loss_D 3 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 --penalty-type 'hinge' 222 | bash fid_script.sh 10 "HingeGAN CIFAR10-resnet lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 223 | 224 | 225 | 226 | python GAN.py --loss_D 4 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 227 | bash fid_script.sh 10 "WGAN CIFAR10-resnet lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 228 | 229 | 230 | python GAN.py --loss_D 4 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 
32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 --l1_margin --penalty-type 'hinge' 231 | bash fid_script.sh 10 "WGAN CIFAR10-resnet lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 232 | 233 | 234 | python GAN.py --loss_D 4 --CIFAR10 True --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --G_h_size 64 --D_h_size 64 --grad_penalty True --arch 2 --penalty-type 'hinge' 235 | bash fid_script.sh 10 "WGAN CIFAR10-resnet lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 236 | 237 | 238 | 239 | 240 | 241 | ## RaGANs 242 | 243 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin 244 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-inf squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 245 | 246 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin --penalty-type 'hinge' 247 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 248 | 249 | 250 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin_smoothmax --smoothmax 1 251 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-inf squared-1 penalty smoothmax 1" 10000 
"$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 252 | 253 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin_smoothmax --smoothmax 1 --penalty-type 'hinge' 254 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-inf hinge penalty smoothmax 1" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 255 | 256 | 257 | 258 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True 259 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 260 | 261 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge' 262 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 263 | 264 | 265 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --linf_margin 266 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 267 | 268 | python GAN.py --loss_D 13 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge' --linf_margin 269 | bash fid_script.sh 10 "RaHingeGAN CIFAR-10 lr .0002 l-1 hinge penalty" 
10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 270 | 271 | 272 | 273 | 274 | 275 | ## RpGANs 276 | 277 | python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin 278 | bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-inf squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 279 | 280 | python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin --penalty-type 'hinge' 281 | bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-inf hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 282 | 283 | 284 | python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin_smoothmax --smoothmax 1 285 | bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-inf squared-1 penalty smoothmax 1" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 286 | 287 | python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --l1_margin_smoothmax --smoothmax 1 --penalty-type 'hinge' 288 | bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-inf hinge penalty smoothmax 1" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz" 289 | 290 | 291 | 292 | python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True 
bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-2 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz"

python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge'
bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-2 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz"


python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --linf_margin
bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-1 squared-1 penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz"

python GAN.py --loss_D 33 --image_size 32 --seed 1 --lr_D .0002 --lr_G .0002 --batch_size 32 --Diters 1 --n_iter 100001 --gen_every 10000 --print_every 5000 --gen_extra_images 50000 --CIFAR10 True --grad_penalty True --penalty-type 'hinge' --linf_margin
bash fid_script.sh 10 "RpHingeGAN CIFAR-10 lr .0002 l-1 hinge penalty" 10000 "$SLURM_TMPDIR/fid_stats/fid_stats_cifar10_train.npz"

--------------------------------------------------------------------------------
/Code/fid.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
''' Calculates the Frechet Inception Distance (FID) to evaluate GANs.

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in .npz format).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.

See --help to see further details.
'''

from __future__ import absolute_import, division, print_function
import numpy as np
import os
import gzip, pickle
import tensorflow as tf
from imageio import imread # previously "from scipy.misc import imread", but bugged now
from scipy import linalg
import pathlib
import urllib

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

class InvalidFIDException(Exception):
    pass


def create_inception_graph(pth):
    """Creates a graph from saved GraphDef file."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile( pth, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString( f.read())
        _ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
#-------------------------------------------------------------------------------


# code for handling inception net derived from
#   https://github.com/openai/improved-gan/blob/master/inception_score/model.py
# updated from here https://github.com/bioinf-jku/TTUR/issues/6
def _get_inception_layer(sess):
    """Prepares inception net for batched usage and returns pool_3 layer. """
    layername = 'FID_Inception_Net/pool_3:0'
    pool3 = sess.graph.get_tensor_by_name(layername)
    ops = pool3.graph.get_operations()
    for op_idx, op in enumerate(ops):
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims != []:
                shape = [s.value for s in shape]
                new_shape = []
                # Replace a leading batch dimension of 1 with None so that
                # the graph accepts batches of any size.
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
    return pool3
#-------------------------------------------------------------------------------


def get_activations(images, sess, batch_size=50, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values
                     must lie between 0 and 255.
    -- sess        : current session
    -- batch_size  : the images numpy array is split into batches with batch size
                     batch_size. A reasonable batch size depends on the available hardware.
    -- verbose     : If set to True, the number of calculated batches is reported.
    Returns:
    -- A numpy array of dimension (num images, 2048) that contains the
       activations of the given tensor when feeding inception with the query tensor.
    """
    inception_layer = _get_inception_layer(sess)
    d0 = images.shape[0]
    if batch_size > d0:
        print("warning: batch size is bigger than the data size. setting batch size to data size")
        batch_size = d0
    # Only whole batches are processed; trailing images that do not fill
    # a complete batch are dropped.
    n_batches = d0//batch_size
    n_used_imgs = n_batches*batch_size
    pred_arr = np.empty((n_used_imgs,2048))
    for i in range(n_batches):
        if verbose:
            print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
        start = i*batch_size
        end = start + batch_size
        batch = images[start:end]
        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
        pred_arr[start:end] = pred.reshape(batch_size,-1)
    if verbose:
        print(" done")
    return pred_arr
#-------------------------------------------------------------------------------


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : The sample mean over activations of the pool_3 layer of the
               inception net for generated samples.
    -- mu2   : The sample mean over activations of the pool_3 layer, precalculated
               on a representative data set.
    -- sigma1: The covariance matrix over activations of the pool_3 layer for
               generated samples.
    -- sigma2: The covariance matrix over activations of the pool_3 layer,
               precalculated on a representative data set.

    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        #warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
#-------------------------------------------------------------------------------


def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values
                     must lie between 0 and 255.
    -- sess        : current session
    -- batch_size  : the images numpy array is split into batches with batch size
                     batch_size. A reasonable batch size depends on the available hardware.
    -- verbose     : If set to True, the number of calculated batches is reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(images, sess, batch_size, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma
#-------------------------------------------------------------------------------


#-------------------------------------------------------------------------------
# The following functions aren't needed for calculating the FID
# they're just here to make this module work as a stand-alone script
# for calculating FID scores
#-------------------------------------------------------------------------------
def check_or_download_inception(inception_path):
    ''' Checks if the path to the inception file is valid, or downloads
        the file if it is not present.
    '''
    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    if inception_path is None:
        inception_path = '/tmp'
    inception_path = pathlib.Path(inception_path)
    model_file = inception_path / 'classify_image_graph_def.pb'
    if not model_file.exists():
        print("Downloading Inception model")
        from urllib import request
        import tarfile
        fn, _ = request.urlretrieve(INCEPTION_URL)
        with tarfile.open(fn, mode='r') as f:
            f.extract('classify_image_graph_def.pb', str(model_file.parent))
    return str(model_file)


def _handle_path(path, sess):
    if path.endswith('.npz'):
        # Precomputed statistics: load the stored mean and covariance.
        f = np.load(path)
        m, s = f['mu'][:], f['sigma'][:]
        f.close()
    else:
        # Folder of images: compute the statistics from the pool_3 activations.
        # NOTE: this branch reads the module-level `args` that is only set in
        # the __main__ block, so it works only when run as a script.
        path = pathlib.Path(path)
        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
        x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
        m, s = calculate_activation_statistics(x, sess, batch_size=args.batch_size)
    return m, s


def calculate_fid_given_paths(paths, inception_path):
    ''' Calculates the FID of two paths.
    '''
    inception_path = check_or_download_inception(inception_path)

    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError("Invalid path: %s" % p)

    # NOTE: `args` is the module-level namespace parsed in the __main__ block.
    log_output = open(f"{args.output_dir}/log_FID.txt", 'a+')

    create_inception_graph(str(inception_path))
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        m1, s1 = _handle_path(paths[0], sess)
        m2, s2 = _handle_path(paths[1], sess)
        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
        s = '%s / Iter %06d : %.4f' % (args.output_name, args.at, fid_value)
        print(s, file=log_output)
        log_output.close()  # release the log file handle
        return fid_value


if __name__ == "__main__":
    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("path", type=str, nargs=2,
        help='Path to the generated images or to .npz statistic files')
    parser.add_argument("-i", "--inception", type=str, default=None,
        help='Path to Inception model (will be downloaded if not provided)')
    parser.add_argument("--gpu", default="", type=str,
        help='GPU to use (leave blank for CPU only)')
    parser.add_argument("--output_dir", default="/network/home/jolicoea/Output/Extra", type=str,
        help='Directory to store logging results')
    parser.add_argument("--output_name", default="", type=str,
        help='Name of method used, to use when logging results')
    parser.add_argument("--at", default=1, type=int,
        help='Iteration number, to use when logging results')
    parser.add_argument("--batch_size", default=50, type=int,
        help='Batch size')
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    fid_value = calculate_fid_given_paths(args.path, args.inception)
    print("FID: ", fid_value)
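The distance formula documented in calculate_frechet_distance can be checked by hand in a simple special case. The sketch below is not part of fid.py: it assumes diagonal covariance matrices, so that the matrix square root reduces to an elementwise square root and no TensorFlow, NumPy, or SciPy is needed. The helper name frechet_distance_diagonal and the sample numbers are hypothetical and chosen only for illustration.

```python
import math


def frechet_distance_diagonal(mu1, var1, mu2, var2):
    """Squared Frechet distance between two Gaussians with diagonal covariances.

    Hypothetical helper for illustration only; it is not part of fid.py.
    mu1, mu2   : sequences of per-dimension means
    var1, var2 : sequences of per-dimension variances (the covariance diagonals)
    """
    if not (len(mu1) == len(var1) == len(mu2) == len(var2)):
        raise ValueError("all inputs must have the same length")
    # ||mu_1 - mu_2||^2
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    # Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)); for diagonal matrices the matrix
    # square root is just the elementwise square root of the diagonal.
    cov_term = sum(a + b - 2.0 * math.sqrt(a * b) for a, b in zip(var1, var2))
    return mean_term + cov_term


# Identical statistics give a distance of zero.
same = frechet_distance_diagonal([0.0, 1.0], [1.0, 4.0], [0.0, 1.0], [1.0, 4.0])

# Mean term: 3^2 + 4^2 = 25; covariance term: (1 + 4 - 4) + (4 + 9 - 12) = 2.
shifted = frechet_distance_diagonal([0.0, 0.0], [1.0, 4.0], [3.0, 4.0], [4.0, 9.0])

print(same)     # 0.0
print(shifted)  # 27.0
```

For full covariance matrices this shortcut does not apply, which is why calculate_frechet_distance relies on scipy.linalg.sqrtm and handles near-singular products and small imaginary components.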
-------------------------------------------------------------------------------- /Code/fid_new.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' Calculates the Frechet Inception Distance (FID) to evalulate GANs. 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (in pickle format). 11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real world 14 | samples respectivly. 15 | 16 | See --help to see further details. 17 | ''' 18 | 19 | from __future__ import absolute_import, division, print_function 20 | import numpy as np 21 | import os 22 | import gzip, pickle 23 | import tensorflow as tf 24 | from scipy.misc import imread 25 | from scipy import linalg 26 | import pathlib 27 | import urllib 28 | 29 | 30 | class InvalidFIDException(Exception): 31 | pass 32 | 33 | 34 | def create_inception_graph(pth): 35 | """Creates a graph from saved GraphDef file.""" 36 | # Creates graph from saved graph_def.pb. 37 | with tf.gfile.FastGFile( pth, 'rb') as f: 38 | graph_def = tf.GraphDef() 39 | graph_def.ParseFromString( f.read()) 40 | _ = tf.import_graph_def( graph_def, name='FID_Inception_Net') 41 | #------------------------------------------------------------------------------- 42 | 43 | 44 | # code for handling inception net derived from 45 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py 46 | def _get_inception_layer(sess): 47 | """Prepares inception net for batched usage and returns pool_3 layer. 
""" 48 | layername = 'FID_Inception_Net/pool_3:0' 49 | pool3 = sess.graph.get_tensor_by_name(layername) 50 | ops = pool3.graph.get_operations() 51 | for op_idx, op in enumerate(ops): 52 | for o in op.outputs: 53 | shape = o.get_shape() 54 | if shape._dims != []: 55 | shape = [s.value for s in shape] 56 | new_shape = [] 57 | for j, s in enumerate(shape): 58 | if s == 1 and j == 0: 59 | new_shape.append(None) 60 | else: 61 | new_shape.append(s) 62 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape) 63 | return pool3 64 | #------------------------------------------------------------------------------- 65 | 66 | 67 | def get_activations(images, sess, batch_size=50, verbose=False): 68 | """Calculates the activations of the pool_3 layer for all images. 69 | 70 | Params: 71 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 72 | must lie between 0 and 256. 73 | -- sess : current session 74 | -- batch_size : the images numpy array is split into batches with batch size 75 | batch_size. A reasonable batch size depends on the disposable hardware. 76 | -- verbose : If set to True and parameter out_step is given, the number of calculated 77 | batches is reported. 78 | Returns: 79 | -- A numpy array of dimension (num images, 2048) that contains the 80 | activations of the given tensor when feeding inception with the query tensor. 81 | """ 82 | inception_layer = _get_inception_layer(sess) 83 | d0 = images.shape[0] 84 | if batch_size > d0: 85 | print("warning: batch size is bigger than the data size. 
setting batch size to data size") 86 | batch_size = d0 87 | n_batches = d0//batch_size 88 | n_used_imgs = n_batches*batch_size 89 | pred_arr = np.empty((n_used_imgs,2048)) 90 | for i in range(n_batches): 91 | if verbose: 92 | print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) 93 | start = i*batch_size 94 | end = start + batch_size 95 | batch = images[start:end] 96 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) 97 | pred_arr[start:end] = pred.reshape(batch_size,-1) 98 | if verbose: 99 | print(" done") 100 | return pred_arr 101 | #------------------------------------------------------------------------------- 102 | 103 | 104 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 105 | """Numpy implementation of the Frechet Distance. 106 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 107 | and X_2 ~ N(mu_2, C_2) is 108 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 109 | 110 | Stable version by Dougal J. Sutherland. 111 | 112 | Params: 113 | -- mu1 : The sample mean over activations of the pool_3 layer of the 114 | inception net (as returned by the function 'calculate_activation_statistics') 115 | for generated samples. 116 | -- mu2 : The sample mean over activations of the pool_3 layer, precalculated 117 | on a representative data set. 118 | -- sigma1: The covariance matrix over activations of the pool_3 layer for 119 | generated samples. 120 | -- sigma2: The covariance matrix over activations of the pool_3 layer, 121 | precalculated on a representative data set. 122 | 123 | Returns: 124 | -- : The Frechet Distance. 
125 | """ 126 | 127 | mu1 = np.atleast_1d(mu1) 128 | mu2 = np.atleast_1d(mu2) 129 | 130 | sigma1 = np.atleast_2d(sigma1) 131 | sigma2 = np.atleast_2d(sigma2) 132 | 133 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" 134 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" 135 | 136 | diff = mu1 - mu2 137 | 138 | # product might be almost singular 139 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 140 | if not np.isfinite(covmean).all(): 141 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps 142 | warnings.warn(msg) 143 | offset = np.eye(sigma1.shape[0]) * eps 144 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 145 | 146 | # numerical error might give slight imaginary component 147 | if np.iscomplexobj(covmean): 148 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 149 | m = np.max(np.abs(covmean.imag)) 150 | raise ValueError("Imaginary component {}".format(m)) 151 | covmean = covmean.real 152 | 153 | tr_covmean = np.trace(covmean) 154 | 155 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 156 | #------------------------------------------------------------------------------- 157 | 158 | 159 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=False): 160 | """Calculation of the statistics used by the FID. 161 | Params: 162 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 163 | must lie between 0 and 255. 164 | -- sess : current session 165 | -- batch_size : the images numpy array is split into batches with batch size 166 | batch_size. A reasonable batch size depends on the available hardware. 167 | -- verbose : If set to True, the number of calculated 168 | batches is reported. 
169 | Returns: 170 | -- mu : The mean over samples of the activations of the pool_3 layer of 171 | the inception model. 172 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 173 | the inception model. 174 | """ 175 | act = get_activations(images, sess, batch_size, verbose) 176 | mu = np.mean(act, axis=0) 177 | sigma = np.cov(act, rowvar=False) 178 | return mu, sigma 179 | 180 | 181 | #------------------ 182 | # The following methods are implemented to obtain a batched version of the activations. 183 | # This has the advantage of reducing memory requirements, at the cost of slightly reduced efficiency. 184 | # - Pyrestone 185 | #------------------ 186 | 187 | 188 | def load_image_batch(files): 189 | """Convenience method for batch-loading images 190 | Params: 191 | -- files : list of paths to image files. Images need to have same dimensions for all files. 192 | Returns: 193 | -- A numpy array of dimensions (num_images,hi, wi, 3) representing the image pixel values. 194 | """ 195 | return np.array([imread(str(fn)).astype(np.float32) for fn in files]) 196 | 197 | def get_activations_from_files(files, sess, batch_size=50, verbose=False): 198 | """Calculates the activations of the pool_3 layer for all images. 199 | 200 | Params: 201 | -- files : list of paths to image files. Images need to have same dimensions for all files. 202 | -- sess : current session 203 | -- batch_size : the images numpy array is split into batches with batch size 204 | batch_size. A reasonable batch size depends on the available hardware. 205 | -- verbose : If set to True, the number of calculated 206 | batches is reported. 207 | Returns: 208 | -- A numpy array of dimension (num images, 2048) that contains the 209 | activations of the given tensor when feeding inception with the query tensor. 
210 | """ 211 | inception_layer = _get_inception_layer(sess) 212 | d0 = len(files) 213 | if batch_size > d0: 214 | print("warning: batch size is bigger than the data size. setting batch size to data size") 215 | batch_size = d0 216 | n_batches = d0//batch_size 217 | n_used_imgs = n_batches*batch_size 218 | pred_arr = np.empty((n_used_imgs,2048)) 219 | for i in range(n_batches): 220 | if verbose: 221 | print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) 222 | start = i*batch_size 223 | end = start + batch_size 224 | batch = load_image_batch(files[start:end]) 225 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) 226 | pred_arr[start:end] = pred.reshape(batch_size,-1) 227 | del batch #clean up memory 228 | if verbose: 229 | print(" done") 230 | return pred_arr 231 | 232 | def calculate_activation_statistics_from_files(files, sess, batch_size=50, verbose=False): 233 | """Calculation of the statistics used by the FID. 234 | Params: 235 | -- files : list of paths to image files. Images need to have same dimensions for all files. 236 | -- sess : current session 237 | -- batch_size : the images numpy array is split into batches with batch size 238 | batch_size. A reasonable batch size depends on the available hardware. 239 | -- verbose : If set to True, the number of calculated 240 | batches is reported. 241 | Returns: 242 | -- mu : The mean over samples of the activations of the pool_3 layer of 243 | the inception model. 244 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 245 | the inception model. 
246 | """ 247 | act = get_activations_from_files(files, sess, batch_size, verbose) 248 | mu = np.mean(act, axis=0) 249 | sigma = np.cov(act, rowvar=False) 250 | return mu, sigma 251 | 252 | #------------------------------------------------------------------------------- 253 | 254 | 255 | #------------------------------------------------------------------------------- 256 | # The following functions aren't needed for calculating the FID 257 | # they're just here to make this module work as a stand-alone script 258 | # for calculating FID scores 259 | #------------------------------------------------------------------------------- 260 | def check_or_download_inception(inception_path): 261 | ''' Checks if the path to the inception file is valid, or downloads 262 | the file if it is not present. ''' 263 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 264 | if inception_path is None: 265 | inception_path = '/tmp' 266 | inception_path = pathlib.Path(inception_path) 267 | model_file = inception_path / 'classify_image_graph_def.pb' 268 | if not model_file.exists(): 269 | print("Downloading Inception model") 270 | from urllib import request 271 | import tarfile 272 | fn, _ = request.urlretrieve(INCEPTION_URL) 273 | with tarfile.open(fn, mode='r') as f: 274 | f.extract('classify_image_graph_def.pb', str(model_file.parent)) 275 | return str(model_file) 276 | 277 | 278 | def _handle_path(path, sess, low_profile=False): 279 | if path.endswith('.npz'): 280 | f = np.load(path) 281 | m, s = f['mu'][:], f['sigma'][:] 282 | f.close() 283 | else: 284 | path = pathlib.Path(path) 285 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 286 | if low_profile: 287 | m, s = calculate_activation_statistics_from_files(files, sess) 288 | else: 289 | x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 290 | m, s = calculate_activation_statistics(x, sess) 291 | del x #clean up memory 292 | return m, s 293 | 294 | 295 | 
def calculate_fid_given_paths(paths, inception_path, low_profile=False): 296 | ''' Calculates the FID of two paths. ''' 297 | inception_path = check_or_download_inception(inception_path) 298 | 299 | for p in paths: 300 | if not os.path.exists(p): 301 | raise RuntimeError("Invalid path: %s" % p) 302 | 303 | create_inception_graph(str(inception_path)) 304 | with tf.Session() as sess: 305 | sess.run(tf.global_variables_initializer()) 306 | m1, s1 = _handle_path(paths[0], sess, low_profile=low_profile) 307 | m2, s2 = _handle_path(paths[1], sess, low_profile=low_profile) 308 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 309 | return fid_value 310 | 311 | 312 | if __name__ == "__main__": 313 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 314 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 315 | parser.add_argument("path", type=str, nargs=2, 316 | help='Path to the generated images or to .npz statistic files') 317 | parser.add_argument("-i", "--inception", type=str, default=None, 318 | help='Path to Inception model (will be downloaded if not provided)') 319 | parser.add_argument("--gpu", default="", type=str, 320 | help='GPU to use (leave blank for CPU only)') 321 | parser.add_argument("--lowprofile", action="store_true", 322 | help='Keep only one batch of images in memory at a time. 
This reduces memory footprint, but may decrease speed slightly.') 323 | args = parser.parse_args() 324 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 325 | fid_value = calculate_fid_given_paths(args.path, args.inception, low_profile=args.lowprofile) 326 | print("FID: ", fid_value) 327 | -------------------------------------------------------------------------------- /Code/fid_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Arg1: Number of folders to check for images 3 | # Arg2: Name of approach to use when logging results 4 | # Arg3: Number of iteration used per set of images (Ex: 10k, 20k, 30k, 40k, 50k iterations have Arg1=5, Arg3=10000) 5 | # Arg4: Directory containing the images of the dataset 6 | 7 | current_iter=$3 8 | for i in $(seq 1 $1); do 9 | python fid.py "$SLURM_TMPDIR/Output/Extra/${i}" "${4}" -i "$SLURM_TMPDIR/models/Inception" --gpu "0" --output_name "${2}" --at $current_iter --output_dir "/scratch/jolicoea/Output/Extra/" 10 | current_iter=$((current_iter + $3)) 11 | done -------------------------------------------------------------------------------- /Code/preprocess_cat_dataset.py: -------------------------------------------------------------------------------- 1 | ## Modified version of https://github.com/microe/angora-blue/blob/master/cascade_training/describe.py by Erik Hovland 2 | ##### I modified it to work with Python 3, changed the paths and made it output a folder for images bigger than 64x64 and a folder for images bigger than 128x128 3 | 4 | import cv2 5 | import glob 6 | import math 7 | import sys 8 | 9 | def rotateCoords(coords, center, angleRadians): 10 | # Positive y is down so reverse the angle, too. 
11 | angleRadians = -angleRadians 12 | xs, ys = coords[::2], coords[1::2] 13 | newCoords = [] 14 | n = min(len(xs), len(ys)) 15 | i = 0 16 | centerX = center[0] 17 | centerY = center[1] 18 | cosAngle = math.cos(angleRadians) 19 | sinAngle = math.sin(angleRadians) 20 | while i < n: 21 | xOffset = xs[i] - centerX 22 | yOffset = ys[i] - centerY 23 | newX = xOffset * cosAngle - yOffset * sinAngle + centerX 24 | newY = xOffset * sinAngle + yOffset * cosAngle + centerY 25 | newCoords += [newX, newY] 26 | i += 1 27 | return newCoords 28 | 29 | def preprocessCatFace(coords, image): 30 | 31 | leftEyeX, leftEyeY = coords[0], coords[1] 32 | rightEyeX, rightEyeY = coords[2], coords[3] 33 | mouthX = coords[4] 34 | if leftEyeX > rightEyeX and leftEyeY < rightEyeY and \ 35 | mouthX > rightEyeX: 36 | # The "right eye" is in the second quadrant of the face, 37 | # while the "left eye" is in the fourth quadrant (from the 38 | # viewer's perspective.) Swap the eyes' labels in order to 39 | # simplify the rotation logic. 40 | leftEyeX, rightEyeX = rightEyeX, leftEyeX 41 | leftEyeY, rightEyeY = rightEyeY, leftEyeY 42 | 43 | eyesCenter = (0.5 * (leftEyeX + rightEyeX), 44 | 0.5 * (leftEyeY + rightEyeY)) 45 | 46 | eyesDeltaX = rightEyeX - leftEyeX 47 | eyesDeltaY = rightEyeY - leftEyeY 48 | eyesAngleRadians = math.atan2(eyesDeltaY, eyesDeltaX) 49 | eyesAngleDegrees = eyesAngleRadians * 180.0 / math.pi 50 | 51 | # Straighten the image and fill in gray for blank borders. 52 | rotation = cv2.getRotationMatrix2D( 53 | eyesCenter, eyesAngleDegrees, 1.0) 54 | imageSize = image.shape[1::-1] 55 | straight = cv2.warpAffine(image, rotation, imageSize, 56 | borderValue=(128, 128, 128)) 57 | 58 | # Straighten the coordinates of the features. 59 | newCoords = rotateCoords( 60 | coords, eyesCenter, eyesAngleRadians) 61 | 62 | # Make the face as wide as the space between the ear bases. 63 | w = abs(newCoords[16] - newCoords[6]) 64 | # Make the face square. 
65 | h = w 66 | # Put the center point between the eyes at (0.5, 0.4) in 67 | # proportion to the entire face. 68 | minX = eyesCenter[0] - w/2 69 | if minX < 0: 70 | w += minX 71 | minX = 0 72 | minY = eyesCenter[1] - h*2/5 73 | if minY < 0: 74 | h += minY 75 | minY = 0 76 | 77 | # Crop the face. 78 | crop = straight[int(minY):int(minY+h), int(minX):int(minX+w)] 79 | # Return the crop. 80 | return crop 81 | 82 | def describePositive(): 83 | output = open('log.txt', 'w') 84 | for imagePath in glob.glob('cat_dataset/*.jpg'): 85 | # Open the '.cat' annotation file associated with this 86 | # image. 87 | input = open('%s.cat' % imagePath, 'r') 88 | # Read the coordinates of the cat features from the 89 | # file. Discard the first number, which is the number 90 | # of features. 91 | coords = [int(i) for i in input.readline().split()[1:]] 92 | # Read the image. 93 | image = cv2.imread(imagePath) 94 | # Straighten and crop the cat face. 95 | crop = preprocessCatFace(coords, image) 96 | if crop is None: 97 | print( 98 | 'Failed to preprocess image at %s.' 
% 99 | imagePath, file=sys.stderr) 100 | continue 101 | # Save the crop to folders based on size 102 | h, w, colors = crop.shape 103 | if min(h,w) >= 32: 104 | Path1 = imagePath.replace("cat_dataset","cats_bigger_than_32x32") 105 | resized_crop = cv2.resize(crop, (32, 32)) 106 | cv2.imwrite(Path1, resized_crop) 107 | if min(h,w) >= 64: 108 | Path1 = imagePath.replace("cat_dataset","cats_bigger_than_64x64") 109 | resized_crop = cv2.resize(crop, (64, 64)) 110 | cv2.imwrite(Path1, resized_crop) 111 | if min(h,w) >= 128: 112 | Path2 = imagePath.replace("cat_dataset","cats_bigger_than_128x128") 113 | resized_crop = cv2.resize(crop, (128, 128)) 114 | cv2.imwrite(Path2, resized_crop) 115 | if min(h,w) >= 256: 116 | Path2 = imagePath.replace("cat_dataset","cats_bigger_than_256x256") 117 | resized_crop = cv2.resize(crop, (256, 256)) 118 | cv2.imwrite(Path2, resized_crop) 119 | # Append the cropped face and its bounds to the 120 | # positive description. 121 | #h, w = crop.shape[:2] 122 | #print (cropPath, 1, 0, 0, w, h, file=output) 123 | 124 | 125 | def main(): 126 | describePositive() 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /Code/pytorch_visualize.py: -------------------------------------------------------------------------------- 1 | # Source: https://gist.github.com/hyqneuron/caaf6087162f1c6361571ee489da260f 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import Parameter 6 | from torch.autograd import Variable, Function 7 | from collections import defaultdict 8 | import graphviz 9 | from graphviz import Digraph 10 | 11 | """ 12 | This is a rather distorted implementation of graph visualization in PyTorch. 13 | This implementation is distorted because PyTorch's autograd is undergoing refactoring right now. 
14 | - neither func.next_functions nor func.previous_functions can be relied upon 15 | - BatchNorm's C backend does not follow the python Function interface 16 | - I'm not even sure whether to use var.creator or var.grad_fn (apparently the source tree and wheel builds use different 17 | interface now) 18 | As a result, we are forced to manually trace the graph, using 2 redundant mechanisms: 19 | - Function.__call__: this allows us to trace all Function creations. Function corresponds to Op in TF 20 | - Module.forward_hook: this is needed because the above method doesn't work for BatchNorm, as the current C backend does 21 | not follow the Python Function interface. 22 | To do graph visualization, follow these steps: 23 | 1. register hooks on model: register_vis_hooks(model) 24 | 2. pass data through model: output = model(input) 25 | 3. remove hooks : remove_vis_hooks() 26 | 4. perform visualization : save_visualization(name, format='svg') # name is a string without extension 27 | """ 28 | 29 | 30 | old_function__call__ = Function.__call__ 31 | 32 | def register_creator(inputs, creator, output): 33 | """ 34 | In the forward pass, our Function.__call__ and BatchNorm.forward_hook both call this method to register the creators 35 | inputs: list of input variables 36 | creator: one of 37 | - Function 38 | - BatchNorm module 39 | output: a single output variable 40 | """ 41 | cid = id(creator) 42 | oid = id(output) 43 | if oid in vars: 44 | return 45 | # connect creator to input 46 | for input in inputs: 47 | iid = id(input) 48 | func_trace[cid][iid] = input 49 | # register input 50 | vars[iid] = input 51 | # connect output to creator 52 | assert type(output) not in [tuple, list, dict] 53 | var_trace[oid][cid] = creator 54 | # register creator and output and all inputs 55 | vars[oid] = output 56 | funcs[cid] = creator 57 | 58 | hooks = [] 59 | 60 | def register_vis_hooks(model): 61 | global var_trace, func_trace, vars, funcs 62 | remove_vis_hooks() 63 | var_trace = 
defaultdict(lambda: {}) # map oid to {cid:creator} 64 | func_trace = defaultdict(lambda: {}) # map cid to {iid:input} 65 | vars = {} # map vid to Variable/Parameter 66 | funcs = {} # map cid to Function/BatchNorm module 67 | del hooks[:] # reset the module-level list of forward hooks in place, needed for hook removal 68 | 69 | def hook_func(module, inputs, output): 70 | assert 'BatchNorm' in module.__class__.__name__ # batchnorms don't have shared superclass 71 | inputs = list(inputs) 72 | for p in [module.weight, module.bias]: 73 | if p is not None: 74 | inputs.append(p) 75 | register_creator(inputs, module, output) 76 | 77 | for mod in model.modules(): 78 | if 'BatchNorm' in mod.__class__.__name__: # batchnorms don't have shared superclass 79 | hook = mod.register_forward_hook(hook_func) 80 | hooks.append(hook) 81 | 82 | def new_function__call__(self, *args, **kwargs): 83 | inputs = [a for a in args if isinstance(a, Variable)] 84 | inputs += [a for a in kwargs.values() if isinstance(a, Variable)] 85 | output = old_function__call__(self, *args, **kwargs) 86 | register_creator(inputs, self, output) 87 | return output 88 | 89 | Function.__call__ = new_function__call__ 90 | 91 | 92 | def remove_vis_hooks(): 93 | for hook in hooks: 94 | hook.remove() 95 | 96 | Function.__call__ = old_function__call__ 97 | 98 | 99 | def save_visualization(name, format='svg'): 100 | g = graphviz.Digraph(format=format) 101 | def sizestr(var): 102 | size = [int(i) for i in list(var.size())] 103 | return str(size) 104 | # add variable nodes 105 | for vid, var in vars.items(): 106 | if isinstance(var, nn.Parameter): 107 | g.node(str(vid), label=sizestr(var), shape='ellipse', style='filled', fillcolor='red') 108 | elif isinstance(var, Variable): 109 | g.node(str(vid), label=sizestr(var), shape='ellipse', style='filled', fillcolor='lightblue') 110 | else: 111 | assert False, var.__class__ 112 | # add creator nodes 113 | for cid in func_trace: 114 | creator = funcs[cid] 115 | g.node(str(cid), 
label=str(creator.__class__.__name__), shape='rectangle', style='filled', fillcolor='orange') 116 | # add edges between creator and inputs 117 | for cid in func_trace: 118 | for iid in func_trace[cid]: 119 | g.edge(str(iid), str(cid)) 120 | # add edges between outputs and creators 121 | for oid in var_trace: 122 | for cid in var_trace[oid]: 123 | g.edge(str(cid), str(oid)) 124 | g.render(name) 125 | 126 | from graphviz import Digraph 127 | import torch 128 | from torch.autograd import Variable 129 | 130 | 131 | def make_dot(var, params): 132 | """ Produces Graphviz representation of PyTorch autograd graph 133 | 134 | Blue nodes are the Variables that require grad, orange are Tensors 135 | saved for backward in torch.autograd.Function 136 | 137 | Args: 138 | var: output Variable 139 | params: dict of (name, Variable) to add names to node that 140 | require grad (TODO: make optional) 141 | """ 142 | param_map = {id(v): k for k, v in params.items()} 143 | print(param_map) 144 | 145 | node_attr = dict(style='filled', 146 | shape='box', 147 | align='left', 148 | fontsize='12', 149 | ranksep='0.1', 150 | height='0.2') 151 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) 152 | seen = set() 153 | 154 | def size_to_str(size): 155 | return '('+(', ').join(['%d'% v for v in size])+')' 156 | 157 | def add_nodes(var): 158 | if var not in seen: 159 | if torch.is_tensor(var): 160 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange') 161 | elif hasattr(var, 'variable'): 162 | u = var.variable 163 | node_name = '%s\n %s' % (param_map.get(id(u)), size_to_str(u.size())) 164 | dot.node(str(id(var)), node_name, fillcolor='lightblue') 165 | else: 166 | dot.node(str(id(var)), str(type(var).__name__)) 167 | seen.add(var) 168 | if hasattr(var, 'next_functions'): 169 | for u in var.next_functions: 170 | if u[0] is not None: 171 | dot.edge(str(id(u[0])), str(id(var))) 172 | add_nodes(u[0]) 173 | if hasattr(var, 'saved_tensors'): 174 | for t in 
var.saved_tensors: 175 | dot.edge(str(id(t)), str(id(var))) 176 | add_nodes(t) 177 | add_nodes(var.grad_fn) 178 | return dot -------------------------------------------------------------------------------- /Code/setting_up_script.sh: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | # pip install opencv-python 3 | # Install PyTorch from source 4 | ## Optional 5 | # Install Tensorflow from source (If you want to be able to calculate the FID) 6 | # Install the latest version of R (If you want to recreate the plots of the bias like in the paper) 7 | 8 | ###### Important note: If downloading from archive doesn't work, download it from http://academictorrents.com/details/c501571c29d16d7f41d159d699d0e7fb37092cbd (and then unzip the 6 zip files) 9 | ###### Try to still download https://archive.org/download/CAT_DATASET/00000003_015.jpg.cat, I don't know if it's corrected in the torrent 10 | 11 | ## Download CAT dataset from site 12 | wget -nc https://archive.org/download/CAT_DATASET/CAT_DATASET_01.zip 13 | wget -nc https://archive.org/download/CAT_DATASET/CAT_DATASET_02.zip 14 | wget -nc https://archive.org/download/CAT_DATASET/00000003_015.jpg.cat 15 | 16 | ## Setting up folder () 17 | unzip CAT_DATASET_01.zip -d cat_dataset 18 | unzip CAT_DATASET_02.zip -d cat_dataset 19 | mv cat_dataset/CAT_00/* cat_dataset 20 | rmdir cat_dataset/CAT_00 21 | mv cat_dataset/CAT_01/* cat_dataset 22 | rmdir cat_dataset/CAT_01 23 | mv cat_dataset/CAT_02/* cat_dataset 24 | rmdir cat_dataset/CAT_02 25 | mv cat_dataset/CAT_03/* cat_dataset 26 | rmdir cat_dataset/CAT_03 27 | mv cat_dataset/CAT_04/* cat_dataset 28 | rmdir cat_dataset/CAT_04 29 | mv cat_dataset/CAT_05/* cat_dataset 30 | rmdir cat_dataset/CAT_05 31 | mv cat_dataset/CAT_06/* cat_dataset 32 | rmdir cat_dataset/CAT_06 33 | 34 | ## Error correction 35 | rm cat_dataset/00000003_019.jpg.cat 36 | mv 00000003_015.jpg.cat cat_dataset/00000003_015.jpg.cat 37 | 38 | ## Removing outliers 
39 | # Corrupted, drawings, badly cropped, inverted, impossible to tell it's a cat, blocked face 40 | cd cat_dataset 41 | rm 00000004_007.jpg 00000007_002.jpg 00000045_028.jpg 00000050_014.jpg 00000056_013.jpg 00000059_002.jpg 00000108_005.jpg 00000122_023.jpg 00000126_005.jpg 00000132_018.jpg 00000142_024.jpg 00000142_029.jpg 00000143_003.jpg 00000145_021.jpg 00000166_021.jpg 00000169_021.jpg 00000186_002.jpg 00000202_022.jpg 00000208_023.jpg 00000210_003.jpg 00000229_005.jpg 00000236_025.jpg 00000249_016.jpg 00000254_013.jpg 00000260_019.jpg 00000261_029.jpg 00000265_029.jpg 00000271_020.jpg 00000282_026.jpg 00000316_004.jpg 00000352_014.jpg 00000400_026.jpg 00000406_006.jpg 00000431_024.jpg 00000443_027.jpg 00000502_015.jpg 00000504_012.jpg 00000510_019.jpg 00000514_016.jpg 00000514_008.jpg 00000515_021.jpg 00000519_015.jpg 00000522_016.jpg 00000523_021.jpg 00000529_005.jpg 00000556_022.jpg 00000574_011.jpg 00000581_018.jpg 00000582_011.jpg 00000588_016.jpg 00000588_019.jpg 00000590_006.jpg 00000592_018.jpg 00000593_027.jpg 00000617_013.jpg 00000618_016.jpg 00000619_025.jpg 00000622_019.jpg 00000622_021.jpg 00000630_007.jpg 00000645_016.jpg 00000656_017.jpg 00000659_000.jpg 00000660_022.jpg 00000660_029.jpg 00000661_016.jpg 00000663_005.jpg 00000672_027.jpg 00000673_027.jpg 00000675_023.jpg 00000692_006.jpg 00000800_017.jpg 00000805_004.jpg 00000807_020.jpg 00000823_010.jpg 00000824_010.jpg 00000836_008.jpg 00000843_021.jpg 00000850_025.jpg 00000862_017.jpg 00000864_007.jpg 00000865_015.jpg 00000870_007.jpg 00000877_014.jpg 00000882_013.jpg 00000887_028.jpg 00000893_022.jpg 00000907_013.jpg 00000921_029.jpg 00000929_022.jpg 00000934_006.jpg 00000960_021.jpg 00000976_004.jpg 00000987_000.jpg 00000993_009.jpg 00001006_014.jpg 00001008_013.jpg 00001012_019.jpg 00001014_005.jpg 00001020_017.jpg 00001039_008.jpg 00001039_023.jpg 00001048_029.jpg 00001057_003.jpg 00001068_005.jpg 00001113_015.jpg 00001140_007.jpg 00001157_029.jpg 00001158_000.jpg 00001167_007.jpg 
00001184_007.jpg 00001188_019.jpg 00001204_027.jpg 00001205_022.jpg 00001219_005.jpg 00001243_010.jpg 00001261_005.jpg 00001270_028.jpg 00001274_006.jpg 00001293_015.jpg 00001312_021.jpg 00001365_026.jpg 00001372_006.jpg 00001379_018.jpg 00001388_024.jpg 00001389_026.jpg 00001418_028.jpg 00001425_012.jpg 00001431_001.jpg 00001456_018.jpg 00001458_003.jpg 00001468_019.jpg 00001475_009.jpg 00001487_020.jpg 42 | rm 00000004_007.jpg.cat 00000007_002.jpg.cat 00000045_028.jpg.cat 00000050_014.jpg.cat 00000056_013.jpg.cat 00000059_002.jpg.cat 00000108_005.jpg.cat 00000122_023.jpg.cat 00000126_005.jpg.cat 00000132_018.jpg.cat 00000142_024.jpg.cat 00000142_029.jpg.cat 00000143_003.jpg.cat 00000145_021.jpg.cat 00000166_021.jpg.cat 00000169_021.jpg.cat 00000186_002.jpg.cat 00000202_022.jpg.cat 00000208_023.jpg.cat 00000210_003.jpg.cat 00000229_005.jpg.cat 00000236_025.jpg.cat 00000249_016.jpg.cat 00000254_013.jpg.cat 00000260_019.jpg.cat 00000261_029.jpg.cat 00000265_029.jpg.cat 00000271_020.jpg.cat 00000282_026.jpg.cat 00000316_004.jpg.cat 00000352_014.jpg.cat 00000400_026.jpg.cat 00000406_006.jpg.cat 00000431_024.jpg.cat 00000443_027.jpg.cat 00000502_015.jpg.cat 00000504_012.jpg.cat 00000510_019.jpg.cat 00000514_016.jpg.cat 00000514_008.jpg.cat 00000515_021.jpg.cat 00000519_015.jpg.cat 00000522_016.jpg.cat 00000523_021.jpg.cat 00000529_005.jpg.cat 00000556_022.jpg.cat 00000574_011.jpg.cat 00000581_018.jpg.cat 00000582_011.jpg.cat 00000588_016.jpg.cat 00000588_019.jpg.cat 00000590_006.jpg.cat 00000592_018.jpg.cat 00000593_027.jpg.cat 00000617_013.jpg.cat 00000618_016.jpg.cat 00000619_025.jpg.cat 00000622_019.jpg.cat 00000622_021.jpg.cat 00000630_007.jpg.cat 00000645_016.jpg.cat 00000656_017.jpg.cat 00000659_000.jpg.cat 00000660_022.jpg.cat 00000660_029.jpg.cat 00000661_016.jpg.cat 00000663_005.jpg.cat 00000672_027.jpg.cat 00000673_027.jpg.cat 00000675_023.jpg.cat 00000692_006.jpg.cat 00000800_017.jpg.cat 00000805_004.jpg.cat 00000807_020.jpg.cat 00000823_010.jpg.cat 
00000824_010.jpg.cat 00000836_008.jpg.cat 00000843_021.jpg.cat 00000850_025.jpg.cat 00000862_017.jpg.cat 00000864_007.jpg.cat 00000865_015.jpg.cat 00000870_007.jpg.cat 00000877_014.jpg.cat 00000882_013.jpg.cat 00000887_028.jpg.cat 00000893_022.jpg.cat 00000907_013.jpg.cat 00000921_029.jpg.cat 00000929_022.jpg.cat 00000934_006.jpg.cat 00000960_021.jpg.cat 00000976_004.jpg.cat 00000987_000.jpg.cat 00000993_009.jpg.cat 00001006_014.jpg.cat 00001008_013.jpg.cat 00001012_019.jpg.cat 00001014_005.jpg.cat 00001020_017.jpg.cat 00001039_008.jpg.cat 00001039_023.jpg.cat 00001048_029.jpg.cat 00001057_003.jpg.cat 00001068_005.jpg.cat 00001113_015.jpg.cat 00001140_007.jpg.cat 00001157_029.jpg.cat 00001158_000.jpg.cat 00001167_007.jpg.cat 00001184_007.jpg.cat 00001188_019.jpg.cat 00001204_027.jpg.cat 00001205_022.jpg.cat 00001219_005.jpg.cat 00001243_010.jpg.cat 00001261_005.jpg.cat 00001270_028.jpg.cat 00001274_006.jpg.cat 00001293_015.jpg.cat 00001312_021.jpg.cat 00001365_026.jpg.cat 00001372_006.jpg.cat 00001379_018.jpg.cat 00001388_024.jpg.cat 00001389_026.jpg.cat 00001418_028.jpg.cat 00001425_012.jpg.cat 00001431_001.jpg.cat 00001456_018.jpg.cat 00001458_003.jpg.cat 00001468_019.jpg.cat 00001475_009.jpg.cat 00001487_020.jpg.cat 43 | cd .. 
44 | 45 | ## Preprocessing and putting in folders for different image sizes 46 | mkdir cats_bigger_than_32x32 47 | mkdir cats_bigger_than_64x64 48 | mkdir cats_bigger_than_128x128 49 | wget -nc https://raw.githubusercontent.com/AlexiaJM/Relativistic-f-divergences/master/preprocess_cat_dataset.py 50 | python preprocess_cat_dataset.py 51 | 52 | ## Removing cat_dataset 53 | rm -r cat_dataset 54 | 55 | ## Move to your favorite place 56 | #mv cats_bigger_than_32x32 /home/alexia/Datasets/Meow_32x32 57 | #mv cats_bigger_than_64x64 /home/alexia/Datasets/Meow_64x64 58 | #mv cats_bigger_than_128x128 /home/alexia/Datasets/Meow_128x128 59 | 60 | ## Create FID stats 61 | # Change to your folders 62 | 63 | # CIFAR-10 64 | # Note that I actually used the one in http://bioinf.jku.at/research/ttur/, but either way should be the same 65 | python create_FID_stats.py --output_path '/home/alexia/fid_stats/CIFAR10_fid_stats32.npz' 66 | 67 | # CAT 32x32 68 | python create_FID_stats.py --data_path '/home/alexia/Datasets/Meow_32x32/cats_bigger_than_32x32' --output_path '/home/alexia/fid_stats/CAT_fid_stats32.npz' 69 | 70 | # CelebA 32x32 71 | python preprocess_dataset.py --centercrop 160 --image_size 32 --input_path '/home/alexia/Datasets/CelebA/img_align_celeba' --output_path '/home/alexia/Datasets/CelebA_transformed32' 72 | python create_FID_stats.py --data_path '/home/alexia/Datasets/CelebA_transformed32' --output_path '/home/alexia/fid_stats/CelebA_fid_stats32.npz' -------------------------------------------------------------------------------- /Code/startup_tmp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Put datasets in node so it will be faster than reading from network. 
## Assuming /project/rpp-bengioy/jolicoea/ -> /scratch/jolicoea/Datasets/
## Example:
# bash startup_tmp.sh dir1="CIFAR10" dir2="Meow_64x64" dir3="Meow_128x128" dir4=""

# Arg1: Input_folder1
# Arg2: Input_folder2
# Arg3: Input_folder3
# Arg4: Input_folder4

# Parse key=value arguments (dir1=..., dir2=..., dir3=..., dir4=...)
for ARGUMENT in "$@"
do

    KEY=$(echo "$ARGUMENT" | cut -f1 -d=)
    VALUE=$(echo "$ARGUMENT" | cut -f2 -d=)

    case "$KEY" in
        dir1) dir1=${VALUE} ;;
        dir2) dir2=${VALUE} ;;
        dir3) dir3=${VALUE} ;;
        dir4) dir4=${VALUE} ;;
        *) ;; # Ignore unknown keys
    esac

done

echo "Setting up fid_stats"
mkdir -p "$SLURM_TMPDIR/Datasets"
cp -r -n "/scratch/jolicoea/fid_stats" "$SLURM_TMPDIR"

# Extract each requested dataset archive into the node-local temp directory
if [ -z "$dir1" ]; then
    echo "Empty directory 1"
else
    echo "Setting up directory $dir1"
    mkdir -p "$SLURM_TMPDIR/Datasets/$dir1" && tar xzf "/project/rpp-bengioy/jolicoea/Datasets/$dir1.tar.gz" -C "$SLURM_TMPDIR/Datasets"
fi
if [ -z "$dir2" ]; then
    echo "Empty directory 2"
else
    echo "Setting up directory $dir2"
    mkdir -p "$SLURM_TMPDIR/Datasets/$dir2" && tar xzf "/project/rpp-bengioy/jolicoea/Datasets/$dir2.tar.gz" -C "$SLURM_TMPDIR/Datasets"
fi
if [ -z "$dir3" ]; then
    echo "Empty directory 3"
else
    echo "Setting up directory $dir3"
    mkdir -p "$SLURM_TMPDIR/Datasets/$dir3" && tar xzf "/project/rpp-bengioy/jolicoea/Datasets/$dir3.tar.gz" -C "$SLURM_TMPDIR/Datasets"
fi
if [ -z "$dir4" ]; then
    echo "Empty directory 4"
else
    echo "Setting up directory $dir4"
    mkdir -p "$SLURM_TMPDIR/Datasets/$dir4" && tar xzf "/project/rpp-bengioy/jolicoea/Datasets/$dir4.tar.gz" -C "$SLURM_TMPDIR/Datasets"
fi

# Make local directories in tempdir
mkdir -p "$SLURM_TMPDIR/Output/Extra/"

# Transfer Inception model
mkdir -p "$SLURM_TMPDIR/models/"
cp -r -n /scratch/jolicoea/models/Inception "$SLURM_TMPDIR/models/"

# Export directory
export SLURM_TMPDIR

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Alexia Jolicoeur-Martineau

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MaximumMarginGANs
Code for paper: [Support Vector Machines, Wasserstein's distance and gradient-penalty GANs maximize a margin](https://arxiv.org/abs/1910.06922)

**Discussion at https://ajolicoeur.wordpress.com/MaximumMarginGANs.**

This is basically the same code as https://github.com/AlexiaJM/relativistic-f-divergences, but with more options.

## Citation

If you use our novel gradient penalties or would like to mention that gradient penalties correspond to having a maximum-margin discriminator, please cite us in your work:
```
@article{jolicoeur2019connections,
  title={Connections between Support Vector Machines, Wasserstein distance and gradient-penalty GANs},
  author={Jolicoeur-Martineau, Alexia},
  journal={arXiv preprint arXiv:1910.06922},
  year={2019}
}
```


**Sample PyTorch code to use L1, L2, Linfinity gradient penalties with hinge or LS:**

```python
import torch
from torch.autograd import Variable

# This snippet goes inside the training loop. It assumes that D (discriminator),
# x (real batch), x_fake (generated batch), u (a pre-allocated tensor),
# grad_outputs (assumed to be a tensor of ones shaped like D's output) and
# current_batch_size are already defined.

# Pick ONE of the two settings below.

# Best setting (novel Hinge Linfinity gradient penalty)
grad_penalty_Lp_norm = 'Linf'
penalty_type = 'hinge'

# Default setting from WGAN-GP and most cases (L2 gradient penalty)
# grad_penalty_Lp_norm = 'L2'
# penalty_type = 'LS'

# Calculate gradient
penalty = 20 # 10 is the more usual choice
u.resize_(current_batch_size, 1, 1, 1)
u.uniform_(0, 1)
x_both = x.data*u + x_fake.data*(1-u) # interpolation between real and fake samples
x_both = x_both.cuda()
x_both = Variable(x_both, requires_grad=True)
y0 = D(x_both)
grad = torch.autograd.grad(outputs=y0, inputs=x_both, grad_outputs=grad_outputs, retain_graph=True,
                           create_graph=True, only_inputs=True)[0]
x_both.requires_grad_(False)
grad = grad.view(current_batch_size, -1) # one flattened gradient per sample

if grad_penalty_Lp_norm == 'Linf': # Linfinity gradient norm penalty (Corresponds to L1 margin, BEST results)
    grad_abs = torch.abs(grad) # Absolute value of gradient
    grad_norm, _ = torch.max(grad_abs, 1)
elif grad_penalty_Lp_norm == 'L1': # L1 gradient norm penalty (Corresponds to Linfinity margin, WORST results)
    grad_norm = grad.norm(1, 1)
else: # L2 gradient norm penalty (Corresponds to L2 margin, this is what people generally use)
    grad_norm = grad.norm(2, 1)

if penalty_type == 'LS': # The usual choice, penalize values below 1 and above 1 (too constraining to properly estimate the Wasserstein distance)
    constraint = (grad_norm-1).pow(2)
elif penalty_type == 'hinge': # Penalize values above 1 only (best choice)
    constraint = torch.nn.ReLU()(grad_norm - 1)

constraint = constraint.mean()
grad_penalty = penalty*constraint
grad_penalty.backward(retain_graph=True)
```

**Needed**

* Python 3.6
* PyTorch (Latest from source)
* TensorFlow (Latest from source, needed to get FID)
* Cat Dataset (http://academictorrents.com/details/c501571c29d16d7f41d159d699d0e7fb37092cbd)

**To do beforehand**

* Change all folder locations in GAN.py (and in startup_tmp.sh, fid_script.sh and experiments.sh if you want FID and replication of the paper)
* Make sure that folders exist at the locations you used
* To get the CAT dataset: open setting_up_script.sh and run the necessary lines in the same folder as preprocess_cat_dataset.py (this downloads the cat dataset automatically; if it doesn't work, download it from http://academictorrents.com/details/c501571c29d16d7f41d159d699d0e7fb37092cbd)

**To run models**
* HingeGAN Linfinity grad norm penalty with max(0, ||grad||-1):
  * python GAN.py --loss_D 3 --image_size 32 --CIFAR10 True --grad_penalty True --l1_margin --penalty-type 'hinge'
* WGAN Linfinity grad norm penalty with max(0, ||grad||-1):
  * python GAN.py --loss_D 4 --image_size 32 --CIFAR10 True --grad_penalty True --l1_margin --penalty-type 'hinge'
* WGAN L2 grad norm penalty with (||grad||-1)^2 (i.e., WGAN-GP):
  * python GAN.py --loss_D 4 --image_size 32 --CIFAR10 True --grad_penalty True

**To replicate the paper**
* Open experiments.sh and run the lines you want

--------------------------------------------------------------------------------
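
**Toy illustration of the norms and penalty types (pure Python)**

The README sample needs a discriminator and a GPU, so the difference between the three gradient norms and the two penalty types can be hard to see at a glance. The following is a minimal sketch, not code from this repository: it works on plain lists that stand in for already-flattened per-sample gradients, and the helper names (`grad_norm`, `constraint`, `gradient_penalty`) and the toy values are assumptions made for illustration only.

```python
import math

def grad_norm(grad, lp_norm):
    """Lp norm of one flattened gradient vector; lp_norm is 'L1', 'L2' or 'Linf'."""
    if lp_norm == 'Linf':
        return max(abs(g) for g in grad)
    if lp_norm == 'L1':
        return sum(abs(g) for g in grad)
    if lp_norm == 'L2':
        return math.sqrt(sum(g * g for g in grad))
    raise ValueError(f"unknown norm: {lp_norm}")

def constraint(norm, penalty_type):
    """Per-sample penalty on a gradient norm."""
    if penalty_type == 'LS':
        # Least squares: penalize norms both below and above 1
        return (norm - 1.0) ** 2
    if penalty_type == 'hinge':
        # Hinge: penalize only norms above 1
        return max(0.0, norm - 1.0)
    raise ValueError(f"unknown penalty type: {penalty_type}")

def gradient_penalty(grads, lp_norm='Linf', penalty_type='hinge', penalty=20.0):
    """Batch mean of the per-sample constraint, scaled by the penalty weight."""
    values = [constraint(grad_norm(g, lp_norm), penalty_type) for g in grads]
    return penalty * sum(values) / len(values)

# Hypothetical batch of two flattened gradients (not real discriminator gradients)
toy_grads = [[0.0, 0.5], [3.0, 4.0]]

print(gradient_penalty(toy_grads, 'Linf', 'hinge', 20.0))
print(gradient_penalty(toy_grads, 'L2', 'LS', 10.0))
```

In this toy batch the first gradient has a norm below 1 under every choice of norm, so the hinge penalty ignores it while the LS penalty still pulls it toward 1.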