├── models
│   ├── __init__.py
│   ├── NormalizedSoftplus.py
│   ├── Matsushita.py
│   ├── LeastSquare.py
│   ├── Matsushita_test.py
│   ├── mlp.py
│   ├── gan.py
│   └── dcgan.py
├── README.md
├── fileUtil.py
├── kde.py
├── plot_results.py
├── LICENSE
└── main.py
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/NormalizedSoftplus.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[ ]:
5 | 
6 | import torch
7 | import math
8 | 
9 | class NSoftPlus(torch.nn.Module):
10 |     def __init__(self):
11 |         super(NSoftPlus, self).__init__()
12 |         self.log2 = math.log(2)
13 | 
14 |     def forward(self, x):
15 |         # Normalized softplus log2(1 + e^x) - 1, which equals 0 at x = 0.
16 |         return torch.log(torch.exp(x).add(1))/self.log2 - 1.0
--------------------------------------------------------------------------------
/models/Matsushita.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[ ]:
5 | 
6 | import torch
7 | 
8 | class MatsushitaTransform(torch.nn.Module):
9 |     def __init__(self):
10 |         super(MatsushitaTransform, self).__init__()
11 | 
12 |     def forward(self, x):
13 |         # Matsushita transform: (x + sqrt(1 + x^2)) / 2.
14 |         return (x + torch.sqrt(1 + torch.pow(x, 2))) * 0.5
15 | 
16 | 
17 | class MatsushitaTransformOne(torch.nn.Module):
18 |     def __init__(self):
19 |         super(MatsushitaTransformOne, self).__init__()
20 | 
21 |     def forward(self, x):
22 |         return x + torch.sqrt(1 + torch.pow(x, 2)) - 1
23 | 
24 | 
25 | class MatsushitaLinkFunc(torch.nn.Module):
26 |     def __init__(self):
27 |         super(MatsushitaLinkFunc, self).__init__()
28 | 
29 |     def forward(self, x):
30 |         # Link function mapping real scores into (0, 1); used as an optional last layer of the critic.
31 |         return 0.5 * (1 + 0.5 * x / torch.sqrt(0.25 * torch.pow(x, 2) + 1))
32 | 
33 | 
34 | class NormalizedMatsushita(torch.nn.Module):
35 |     def __init__(self, mu=0):
36 |         super(NormalizedMatsushita, self).__init__()
37 |         self.mu = mu
38 |         self.one_minus_mu = 1 - self.mu
39 |         self.one_minus_mu_square = self.one_minus_mu * self.one_minus_mu
40 | 
41 |     def forward(self, x):
42 |         # mu-ReLU: recovers ReLU(x) at mu = 1 and the shifted Matsushita curve at mu = 0.
43 |         return 0.5 * (x + torch.sqrt(self.one_minus_mu_square + torch.pow(x, 2)) - self.one_minus_mu)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # f-GANs in an Information Geometric Nutshell
2 | 
3 | PyTorch implementation of [f-GANs in an Information Geometric Nutshell](http://arxiv.org/abs/)
4 | 
5 | ## Prerequisites
6 | 
7 | - Python 2.7
8 | - [PyTorch 0.1.12](http://pytorch.org)
9 | - [numpy 1.12.1](http://www.numpy.org/)
10 | 
11 | 
12 | ## Usage
13 | 
14 | Put both mnist and lsun in a folder called DATA_ROOT. Download lsun with https://github.com/fyu/lsun. MNIST will be downloaded automatically on the first run.
15 | 
16 |     $ python download.py -o DATA_ROOT -c tower
17 | 
18 | 
19 | Assume all experimental results are put in EXPERIMENTAL_RESULTS.
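20 | 
21 | The mu-ReLU referenced below is implemented as `NormalizedMatsushita` in `models/Matsushita.py`. A minimal sanity check of its limiting behaviour (a sketch, assuming the PyTorch 0.1.x API listed above):
22 | 
23 |     import torch
24 |     from torch.autograd import Variable
25 |     from models.Matsushita import NormalizedMatsushita
26 | 
27 |     x = Variable(torch.linspace(-3, 3, 7))
28 |     print(NormalizedMatsushita(mu=1)(x))  # mu = 1 recovers ReLU: 0.5 * (x + |x|)
29 |     print(NormalizedMatsushita(mu=0)(x))  # mu = 0 is the smooth curve 0.5 * (x + sqrt(1 + x^2) - 1)
30 | 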
31 | Evaluate a feedforward network with the Wasserstein GAN loss and mu-ReLU as the activation of the generator's hidden layers:
32 | 
33 |     $ python main.py --dataset mnist --dataroot DATA_ROOT --cuda -D wgan -A mlp -H murelu --experiment EXPERIMENTAL_RESULTS --task mu
34 | 
35 | Evaluate DCGAN with the standard GAN loss and mu-ReLU as the activation of the generator's hidden layers:
36 | 
37 |     $ python main.py --dataset lsun --subset tower --dataroot DATA_ROOT --cuda -D gan -A dcgan -H murelu --experiment EXPERIMENTAL_RESULTS --task mu
38 | 
39 | 
40 | ## Author
41 | 
42 | Lizhen Qu / [@qulizhen](https://cecs.anu.edu.au/people/lizhen-qu)
--------------------------------------------------------------------------------
/models/LeastSquare.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | # In[ ]:
4 | 
5 | import torch
6 | from torch.autograd.function import Function
7 | 
8 | class LeastSquareFunc(Function):
9 | 
10 |     def forward(self, input):
11 |         mask_le_mone = input.le(-1).type_as(input)
12 |         self.mask_ge_one = input.ge(1).type_as(input)
13 |         index_gt_mone = input.gt(-1).type_as(input)
14 |         index_lt_one = input.lt(1).type_as(input)
15 |         self.mask_mone_one = index_lt_one * index_gt_mone
16 |         mone = input.new().resize_as_(input).fill_(-1)
17 |         mone = mone * mask_le_mone
18 |         between_one = torch.pow(1 + input, 2) - 1
19 |         between_one = between_one * self.mask_mone_one
20 |         ge_one = input * 4 - 1
21 |         ge_one = ge_one * self.mask_ge_one
22 |         between_one = mone + between_one
23 |         ge_one = between_one + ge_one
24 |         self.input = input
25 |         return ge_one
26 | 
27 |     def backward(self, grad_output):
28 |         grad_between = (self.input * 2 + 2) * self.mask_mone_one
29 |         grad_ge_one = 4 * self.mask_ge_one
30 |         grad_input = grad_output * (grad_between + grad_ge_one)
31 | 
32 |         return grad_input
33 | 
34 | 
35 | 
36 | 
37 | 
38 | class LeastSquare(torch.nn.Module):
39 |     def __init__(self):
40 |         super(LeastSquare, self).__init__()
41 | 
42 |     def forward(self, input):
43 |         return LeastSquareFunc()(input)
44 | 
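45 | # The piecewise map implemented by LeastSquareFunc above is:
46 | #   f(x) = -1              for x <= -1
47 | #   f(x) = (1 + x)^2 - 1   for -1 < x < 1
48 | #   f(x) = 4x - 1          for x >= 1
49 | # The pieces agree in value and slope at both joins (f(-1) = -1 with slope 0;
50 | # f(1) = 3 with slope 4), so backward() returns 2(x + 1) on the middle piece,
51 | # 4 on the right piece and 0 elsewhere, which is what the gradcheck test in
52 | # models/Matsushita_test.py verifies.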
--------------------------------------------------------------------------------
/fileUtil.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | if sys.version_info[0] == 3:
5 |     import pickle
6 | else:
7 |     import cPickle as pickle
8 | 
9 | 
10 | import errno
11 | 
12 | 
13 | class FileUtil:
14 | 
15 | 
16 |     @staticmethod
17 |     def validate_file(file_path):
18 |         """
19 |         Method to ensure that a file exists at the given path.
20 |         If it doesn't exist, it will create a blank file and
21 |         the necessary folders at the path provided.
22 |         :param file_path: Path to the file to be validated
23 |         :return: None
24 |         """
25 | 
26 |         if not os.path.isfile(file_path):
27 |             try:
28 |                 # Create the directories if they do not already exist
29 |                 os.makedirs(os.path.dirname(file_path))
30 |             except OSError as exc:  # Guard against race condition
31 |                 if exc.errno != errno.EEXIST:
32 |                     raise
33 |             file = open(file_path, 'w+')
34 |             file.close()
35 | 
36 |     @staticmethod
37 |     def validate_folder(folder_path):
38 |         """
39 |         Same as validate_file but for folders
40 |         :param folder_path: Path to the folder
41 |         :return: None
42 |         """
43 |         if not os.path.isdir(folder_path):
44 |             try:
45 |                 os.makedirs(os.path.abspath(folder_path))
46 |             except OSError as exc:  # Guard against race condition
47 |                 if exc.errno != errno.EEXIST:
48 |                     raise
49 | 
50 |     @staticmethod
51 |     def is_folder(folder_path):
52 |         return os.path.isdir(folder_path)
53 | 
54 |     @staticmethod
55 |     def is_file(file_path):
56 |         return os.path.isfile(file_path)
57 | 
58 |     @staticmethod
59 |     def clear_folder(folder_path):
60 |         """
61 |         Clears all files at the specified folder path
62 |         :param folder_path:
63 |         :return: None
64 |         """
65 |         for file_ in os.listdir(folder_path):
66 |             file_path = os.path.join(folder_path, file_)
67 |             try:
68 |                 if os.path.isfile(file_path):
69 |                     os.unlink(file_path)
70 |                 # elif os.path.isdir(file_path): shutil.rmtree(file_path)
71 |             except Exception as e:
72 |                 print(e)
73 | 
74 |     @staticmethod
75 |     def dump_object(obj, file_path):
76 |         """
77 |         Dumps an object into a file at the path provided
78 |         :param obj: Object that needs to be dumped
79 |         :param file_path: Path of the target pickle file
80 |         :return: None
81 |         """
82 |         FileUtil.validate_file(file_path)
83 |         with open(file_path, "wb") as file:
84 |             pickle.dump(obj, file)
85 | 
86 | 
87 |     @staticmethod
88 |     def load_object(file_path):
89 | 
90 |         obj = None
91 |         with open(file_path, "rb") as file:
92 |             obj = pickle.load(file)
93 | 
94 |         return obj
95 | 
96 | 
--------------------------------------------------------------------------------
/models/Matsushita_test.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding: utf-8
3 | 
4 | # In[ ]:
5 | import os,sys
6 | par_dir = os.path.join(os.pardir, os.pardir)
7 | curr_path = os.path.abspath(__file__)
8 | root_path = os.path.abspath(os.path.join(curr_path, par_dir))
9 | 
10 | sys.path.append(str(root_path))
11 | 
12 | import unittest
13 | import torch
14 | from torch import nn
15 | from unittest import TestCase
16 | from torch.autograd import Variable
17 | 
18 | from models.Matsushita import MatsushitaTransform,NormalizedMatsushita
19 | from torch.autograd import gradcheck
20 | 
21 | from models.NormalizedSoftplus import NSoftPlus
22 | from models.LeastSquare import LeastSquare
23 | 
24 | class MatsushitaModuleTest(TestCase):
25 | 
26 |     def test_forward_prop(self):
27 |         m = MatsushitaTransform()
28 |         v = Variable(torch.FloatTensor([1 , 2]))
29 |         o = m.forward(v)
30 |         ground_truth = torch.FloatTensor([1.2071067811865475 , 2.118033988749895])
31 |         self.assertTrue(torch.equal(ground_truth,o.data))
32 | 
33 |     def test_forward_prop_least_square(self):
34 |         m = LeastSquare()
35 |         v = Variable(torch.FloatTensor([-2 , 1, 0.5, -1, 2]))
36 |         o = m.forward(v)
37 |         ground_truth = torch.FloatTensor([-1, 3, 1.25, -1, 7])
38 |         self.assertTrue(torch.equal(ground_truth,o.data))
39 | 
40 |     def test_backward_prop_least_square(self):
41 |         m = LeastSquare()
42 |         v = Variable(torch.FloatTensor([-2 , 1, 0.5, -1, 2]), requires_grad=True)
43 |         o = m.forward(v)
44 | 
l = o.sum() 45 | l.backward() 46 | ground_truth = torch.FloatTensor([0, 4, 3, 0, 4]) 47 | self.assertTrue(torch.equal(ground_truth, v.grad.data)) 48 | 49 | def test_scala_backward(self): 50 | m = MatsushitaTransform() 51 | criterion = torch.nn.MSELoss() 52 | y = Variable(torch.FloatTensor([1.118033988749895]), requires_grad=False) 53 | v = Variable(torch.FloatTensor([2]), requires_grad=True) 54 | o = m.forward(v) 55 | loss = criterion(o, y) 56 | loss.backward() 57 | self.assertEqual(v.grad[0], 1.8944) 58 | # 1.8944 59 | 60 | def test_backward_prop(self): 61 | input = (Variable(torch.randn(20, 10).double(), requires_grad=True),) 62 | test = gradcheck(MatsushitaTransform(), input, eps=1e-6, atol=1e-4) 63 | self.assertTrue(test) 64 | 65 | def test_backward_prop_normalized(self): 66 | input = (Variable(torch.randn(20, 10).double(), requires_grad=True),) 67 | test = gradcheck(NormalizedMatsushita(), input, eps=1e-6, atol=1e-4) 68 | self.assertTrue(test) 69 | 70 | def test_leastSquare(self): 71 | input = (Variable(torch.randn(20, 10).double(), requires_grad=True),) 72 | test = gradcheck(LeastSquare(), input, eps=1e-6, atol=1e-4) 73 | self.assertTrue(test, 'least square failed') 74 | 75 | def test_NSoftPlus(self): 76 | input = (Variable(torch.randn(20, 10).double(), requires_grad=True),) 77 | test = gradcheck(NSoftPlus(), input, eps=1e-6, atol=1e-4) 78 | self.assertTrue(test) 79 | 80 | if __name__ == '__main__': 81 | unittest.main() -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import torch.nn as nn 7 | from models.Matsushita import MatsushitaTransform,MatsushitaTransformOne,NormalizedMatsushita 8 | 9 | class MLP_G(nn.Module): 10 | def __init__(self, isize, nz, nc, ngf, ngpu, hidden_activation='', mu=0.5, last_layer='none'): 11 | super(MLP_G, self).__init__() 12 | self.ngpu = ngpu 13 | if hidden_activation == 'matsu': 14 | first_activation = MatsushitaTransform() 15 | second_activation = MatsushitaTransform() 16 | third_activation = MatsushitaTransform() 17 | elif hidden_activation == 'matsu1': 18 | first_activation = MatsushitaTransformOne() 19 | second_activation = MatsushitaTransformOne() 20 | third_activation = MatsushitaTransformOne() 21 | elif hidden_activation == 'elu': 22 | first_activation = nn.ELU(alpha=mu) 23 | second_activation = nn.ELU(alpha=mu) 24 | third_activation = nn.ELU(alpha=mu) 25 | elif hidden_activation == 'murelu': 26 | first_activation = NormalizedMatsushita(mu) 27 | second_activation = NormalizedMatsushita(mu) 28 | third_activation = NormalizedMatsushita(mu) 29 | else: 30 | first_activation = nn.ReLU(False) 31 | second_activation = nn.ReLU(False) 32 | third_activation = nn.ReLU(False) 33 | main = nn.Sequential( 34 | # Z goes into a linear of size: ngf 35 | nn.Linear(nz, ngf), 36 | first_activation, 37 | nn.Linear(ngf, ngf), 38 | second_activation, 39 | nn.Linear(ngf, ngf), 40 | third_activation, 41 | nn.Linear(ngf, nc * isize * isize), 42 | ) 43 | 44 | if last_layer == 'tanh': 45 | main.add_module('final.{0}.tanh'.format(nc), 46 | nn.Tanh()) 47 | elif last_layer == 'sigmoid': 48 | main.add_module('final.{0}.sigmoid'.format(nc), 49 | nn.Sigmoid()) 50 | 51 | self.main = main 52 | self.nc = nc 53 | self.isize = isize 54 | self.nz = nz 55 | 56 | def forward(self, 
input): 57 | input = input.view(input.size(0), input.size(1)) 58 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 59 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 60 | else: 61 | output = self.main(input) 62 | return output.view(output.size(0), self.nc, self.isize, self.isize) 63 | 64 | 65 | class MLP_D(nn.Module): 66 | def __init__(self, isize, nz, nc, ndf, ngpu): 67 | super(MLP_D, self).__init__() 68 | self.ngpu = ngpu 69 | 70 | main = nn.Sequential( 71 | # Z goes into a linear of size: ndf 72 | nn.Linear(nc * isize * isize, ndf), 73 | nn.ReLU(False), 74 | nn.Linear(ndf, ndf), 75 | nn.ReLU(False), 76 | nn.Linear(ndf, ndf), 77 | nn.ReLU(False), 78 | nn.Linear(ndf, 1), 79 | ) 80 | self.main = main 81 | self.nc = nc 82 | self.isize = isize 83 | self.nz = nz 84 | 85 | def forward(self, input): 86 | input = input.view(input.size(0), 87 | input.size(1) * input.size(2) * input.size(3)) 88 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 89 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 90 | else: 91 | output = self.main(input) 92 | #output = output.mean(0) 93 | return output.view(input.size(0)) 94 | -------------------------------------------------------------------------------- /models/gan.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import torch.nn as nn 7 | from models.Matsushita import NormalizedMatsushita,MatsushitaLinkFunc 8 | from models.NormalizedSoftplus import NSoftPlus 9 | from models.LeastSquare import LeastSquare 10 | 11 | class GAN_G(nn.Module): 12 | def __init__(self, isize, nz, nc, ngf, ngpu, hidden_activation, mu, last_layer='sigmoid'): 13 | super(GAN_G, self).__init__() 14 | self.ngpu = ngpu 15 | if hidden_activation == 'elu': 16 | first_activation = nn.ELU(mu) 17 | second_activation = nn.ELU(mu) 18 | elif hidden_activation == 'murelu': 19 | first_activation = NormalizedMatsushita(mu) 20 | second_activation = NormalizedMatsushita(mu) 21 | elif hidden_activation == 'ls': 22 | first_activation = LeastSquare() 23 | second_activation = LeastSquare() 24 | elif hidden_activation == 'sp': 25 | first_activation = NSoftPlus() 26 | second_activation = NSoftPlus() 27 | else: 28 | first_activation = nn.ReLU(False) 29 | second_activation = nn.ReLU(False) 30 | 31 | main = nn.Sequential( 32 | # Z goes into a linear of size: ngf 33 | nn.Linear(nz, ngf), 34 | nn.BatchNorm1d(ngf), 35 | first_activation, 36 | nn.Linear(ngf, ngf), 37 | nn.BatchNorm1d(ngf), 38 | second_activation, 39 | nn.Linear(ngf, nc * isize * isize) 40 | ) 41 | if last_layer == 'sigmoid': 42 | main.add_module('top_sigmoid', torch.nn.Sigmoid()) 43 | elif last_layer == 'tanh': 44 | main.add_module('top_tanh',torch.nn.Tanh()) 45 | 46 | self.main = main 47 | self.nc = nc 48 | self.isize = isize 49 | self.nz = nz 50 | 51 | def forward(self, input): 52 | input = input.view(input.size(0), input.size(1)) 53 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 54 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 55 | else: 56 | output = self.main(input) 57 | return output.view(output.size(0), self.nc, self.isize, self.isize) 58 | 59 | 60 | class GAN_D(nn.Module): 61 | def __init__(self, isize, nz, nc, ndf, ngpu,hidden_activation = 'relu', last_layer='', alpha=1.0): 62 | 
super(GAN_D, self).__init__() 63 | self.ngpu = ngpu 64 | 65 | if hidden_activation == 'elu': 66 | first_activation = nn.ELU(alpha=alpha) 67 | second_activation = nn.ELU(alpha=alpha) 68 | else: 69 | first_activation = nn.ReLU(False) 70 | second_activation = nn.ReLU(False) 71 | 72 | main = nn.Sequential( 73 | # Z goes into a linear of size: ndf 74 | nn.Linear(nc * isize * isize, ndf), 75 | first_activation, 76 | nn.Linear(ndf, ndf), 77 | second_activation, 78 | nn.Linear(ndf, 1) 79 | ) 80 | if last_layer == 'sigmoid': 81 | main.add_module('top_sigmoid', torch.nn.Sigmoid()) 82 | elif last_layer == 'tanh': 83 | main.add_module('top_tanh',torch.nn.Tanh()) 84 | elif last_layer == 'matsu': 85 | main.add_module('final.{0}.Matsushita'.format(nc), 86 | MatsushitaLinkFunc()) 87 | 88 | self.main = main 89 | self.nc = nc 90 | self.isize = isize 91 | self.nz = nz 92 | 93 | def forward(self, input): 94 | input = input.view(input.size(0), 95 | input.size(1) * input.size(2) * input.size(3)) 96 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 97 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 98 | else: 99 | output = self.main(input) 100 | return output.view(input.size(0)) -------------------------------------------------------------------------------- /kde.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import random 4 | from scipy.stats import gaussian_kde 5 | import torchvision.datasets as dset 6 | import torchvision.transforms as transforms 7 | import torchvision.utils as vutils 8 | import torch 9 | import numpy as np 10 | from sklearn.neighbors import KernelDensity 11 | from sklearn.model_selection import GridSearchCV 12 | from fileUtil import FileUtil 13 | 14 | def convert_to_ndarrays(dataset): 15 | return np.stack([vec[0].numpy().flatten() for vec in dataset]) 16 | 17 | 18 | def fit_kde(X, bandwidth): 19 | kde = KernelDensity(bandwidth=bandwidth) 20 | kde.fit(X=X) 21 | return kde 22 | 23 | 24 | def cal_logprob(kde, data): 25 | #logprob_vec = kde.score_samples(data) 26 | #mean_logprob = np.mean(logprob_vec) 27 | #max_p = np.max(logprob_vec) 28 | #normalized_logp = max_p + np.log(np.mean(np.exp(logprob_vec - max_p))) - (original_data_size - 1) * np.log(sigma * np.sqrt(np.pi * 2)) 29 | return kde.score(data) / data.shape[0] 30 | 31 | 32 | def search_bandwidth(val_data, cvJobs): 33 | data = convert_to_ndarrays(val_data) 34 | params = {'bandwidth': np.logspace(-1, 1, 20)} 35 | grid = GridSearchCV(KernelDensity(), params, n_jobs=cvJobs) 36 | grid.fit(data) 37 | print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth)) 38 | return grid.best_estimator_.bandwidth 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--dataset', required=True, help='cifar10 | lsun | mnist') 43 | parser.add_argument('--dataroot', required=True, help='path to dataset') 44 | parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network') 45 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 46 | parser.add_argument('--cvJobs', type=int, help='number of jobs for cross validation', default=4) 47 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 48 | parser.add_argument('--modelPath', default='', help="path to the kernel density estimation model.") 49 | parser.add_argument('--task', default='hyper', help="hyper | 
train") 50 | parser.add_argument('--normalizeImages', type=bool, default=True) 51 | opt = parser.parse_args() 52 | print(opt) 53 | 54 | 55 | opt.manualSeed = random.randint(1, 10000) # fix seed 56 | nc = 3 # number of channels 57 | if opt.dataset == 'lsun': 58 | #3x256x341 59 | if opt.normalizeImages: 60 | transform_op = transforms.Compose([ 61 | transforms.Scale(opt.imageSize), 62 | transforms.CenterCrop(opt.imageSize), 63 | transforms.ToTensor(), 64 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 65 | ]) 66 | else: 67 | transform_op = transforms.Compose([ 68 | transforms.Scale(opt.imageSize), 69 | transforms.CenterCrop(opt.imageSize), 70 | transforms.ToTensor(), 71 | ]) 72 | dataset = dset.LSUN(db_path=opt.dataroot, classes=['tower_train'], 73 | transform=transform_op) 74 | val_dataset = [dataset[i] for i in range(50000, 60000)] 75 | test_dataset = dset.LSUN(db_path=opt.dataroot, classes=['tower_val'], 76 | transform=transform_op) 77 | elif opt.dataset == 'cifar10': 78 | dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=True, 79 | transform=transforms.Compose([ 80 | transforms.Scale(opt.imageSize), 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 83 | ])) 84 | 85 | val_dataset = [dataset[i] for i in range(10000, 20000)] 86 | test_dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=False, 87 | transform=transforms.Compose([ 88 | transforms.Scale(opt.imageSize), 89 | transforms.ToTensor(), 90 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 91 | ])) 92 | elif opt.dataset == 'mnist': 93 | dataset = dset.MNIST(root=opt.dataroot, download=True, train=True, 94 | transform=transforms.Compose([ 95 | transforms.Scale(opt.imageSize), 96 | transforms.ToTensor(), 97 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 98 | ])) 99 | val_dataset = [dataset[i] for i in range(50000,60000)] 100 | 101 | test_dataset = dset.MNIST(root=opt.dataroot, download=True,train=False, 102 | transform=transforms.Compose([ 103 | transforms.Scale(opt.imageSize), 104 | transforms.ToTensor(), 105 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 106 | ])) 107 | nc = 1 108 | assert dataset 109 | assert val_dataset 110 | assert test_dataset 111 | 112 | 113 | if opt.task == 'hyper': 114 | search_bandwidth(val_dataset, opt.cvJobs) 115 | else: 116 | train_set = convert_to_ndarrays(dataset) 117 | print('max value: {0} , min value: {1}'.format(np.max(train_set), np.min(train_set))) 118 | if opt.imageSize == 32: 119 | b_width = 0.1 120 | else: 121 | b_width = 0.12742749857 122 | kde = fit_kde(train_set, bandwidth=b_width) 123 | mean_logprob = cal_logprob(kde, convert_to_ndarrays(test_dataset)) 124 | print('mean log probability : {0}'.format(mean_logprob)) 125 | # MNIST, size 64, bandwidth 0.206913808111 126 | # MNIST size 32 unnormalized 0.1 logprob 880.783584576 127 | # MNIST, size 28, bandwidth 0.263665089873 unnormalized 0.12742749857 logprob 526.829087276 128 | # CIFAR10, 32, bandwidth 0.263665089873 129 | 130 | if __name__ == '__main__': 131 | main() -------------------------------------------------------------------------------- /plot_results.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import csv 4 | import numpy as np 5 | from os import listdir 6 | from os.path import isfile, join, isdir 7 | import argparse 8 | 9 | def read(file_path): 10 | with open(file_path, 'r') as csvfile: 11 | reader = csv.reader(csvfile, delimiter=',') 12 | log_probs = 
[] 13 | for row in reader: 14 | log_probs.append(float(row[1])) 15 | 16 | assert len(log_probs) == 11, 'length %s in %s ' % (len(log_probs), file_path) 17 | return log_probs 18 | 19 | def read_files(file_dir): 20 | file_paths = [join(file_dir, f) for f in listdir(file_dir) if isfile(join(file_dir, f))] 21 | log_prob_matrix = [read(file) for file in file_paths] 22 | std_dev = np.std(log_prob_matrix, axis=0) 23 | mean = np.mean(log_prob_matrix, axis=0) 24 | print(file_dir) 25 | print('mean: {0}'.format(mean)) 26 | print('std: {0}'.format(std_dev)) 27 | return mean, std_dev 28 | 29 | def plot_results(main_dir, target_dir): 30 | mean_std_tuples = [read_files(join(main_dir, folder)) for folder in listdir(main_dir) if isdir(join(main_dir, folder))] 31 | mean, std_dev = zip(*mean_std_tuples) 32 | mus = range(0,11) 33 | plt.gca().set_color_cycle(['black', 'orange', 'deeppink', 'green', 'blue']) 34 | fmts = ['>-', 'o-', 'x-', 'D-', 'p-'] 35 | exps = [folder for folder in listdir(main_dir) if isdir(join(main_dir, folder))] 36 | for i in range(len(mean)): 37 | plt.errorbar(mus, mean[i], yerr=std_dev[i], fmt=fmts[i]) 38 | 39 | plt.legend(exps, loc='lower right') 40 | plt.xlim((-0.2,10.2)) 41 | #plt.axis((0,11, 0.1, 0.8)) 42 | plt.grid(True) 43 | 44 | x_ticks = ['0', '0.1', '0.2', '0.3', '0.4', '0.5', '0.6', '0.7', '0.8', '0.9','1'] 45 | plt.xticks(mus, x_ticks) 46 | 47 | plt.xlabel('mu') 48 | plt.ylabel('Log Probability') 49 | #plt.show() 50 | plt.savefig(join(target_dir, "mu_relu.png"), bbox_inches='tight') 51 | 52 | def barplot(target_dir): 53 | softplus_gan_dcgan = [523.749325889, 612.327717494, 582.548003606] 54 | softplus_wgan_dcgan = [511.132247894, 527.427932984, 442.807643129] 55 | softplus_gan_mlp = [120.218675105, 118.340705537, 144.120714021] 56 | softplus_wgan_mlp = [453.721877243, 501.204041358, 444.071567208] 57 | leastSquare_gan_dcgan = [632.135430781, 662.210057543, 669.320158872] 58 | leastSquare_wgan_dcgan = [528.636335214, 560.550922099, 544.230405069] 59 | leastSquare_gan_mlp = [231.680103473, 163.158618896, 197.583367488] 60 | leastSquare_wgan_mlp = [423.829548358, 507.111113457, 484.804082098] 61 | 62 | N=4 63 | ind = np.arange(N) # the x locations for the groups 64 | width = 0.2 # the width of the bars 65 | 66 | fig, ax = plt.subplots() 67 | 68 | softplus_mean = [np.mean(softplus_gan_dcgan), np.mean(softplus_wgan_dcgan), np.mean(softplus_gan_mlp), np.mean(softplus_wgan_mlp)] 69 | softplus_std = [np.std(softplus_gan_dcgan), np.std(softplus_wgan_dcgan), np.std(softplus_gan_mlp), np.std(softplus_wgan_mlp)] 70 | rects_softplus = ax.bar(ind, softplus_mean, width, color='lightblue', yerr=softplus_std) 71 | 72 | ls_mean = [np.mean(leastSquare_gan_dcgan), np.mean(leastSquare_wgan_dcgan), np.mean(leastSquare_gan_mlp), 73 | np.mean(leastSquare_wgan_mlp)] 74 | ls_std = [np.std(leastSquare_gan_dcgan), np.std(leastSquare_wgan_dcgan), np.std(leastSquare_gan_mlp), 75 | np.std(leastSquare_wgan_mlp)] 76 | rects_ls = ax.bar(ind + width, ls_mean, width, color='lightgreen', yerr=ls_std) 77 | 78 | relu_mean = [681.01717885, 542.51281229, 457.24837393, 413.81168954] 79 | relu_std = [25.93946989, 21.66889571, 35.52057633, 125.71019851] 80 | rects_relu = ax.bar(ind + 2*width, relu_mean, width, color='yellow', yerr=relu_std) 81 | 82 | 83 | ax.set_ylabel('Log Probability') 84 | ax.set_xticks(ind + 1.5* width) 85 | ax.set_xticklabels(('gan_dcgan', 'wgan_dcgan', 'gan_mlp', 'wgan_mlp')) 86 | 87 | ax.legend((rects_softplus[0], rects_ls[0],rects_relu[0]), ('Softplus', 'LeastSquare', 'ReLU')) 88 | 89 | # 
plt.show() 90 | plt.savefig(join(target_dir, "sp_ls.png"), bbox_inches='tight') 91 | 92 | 93 | def autolabel(rects): 94 | """ 95 | Attach a text label above each bar displaying its height 96 | """ 97 | for rect in rects: 98 | height = rect.get_height() 99 | ax.text(rect.get_x() + rect.get_width()/3., 1.05*height, 100 | '%d' % int(height), 101 | ha='center', va='bottom') 102 | 103 | autolabel(rects_softplus) 104 | autolabel(rects_ls) 105 | autolabel(rects_relu) 106 | 107 | def discriminator_barplot(target_dir): 108 | matsu_gan_dcgan = [676.460141029,705.85256067,700.876242004] 109 | matsu_wgan_dcgan = [564.063737896,524.636358032,543.634910961] 110 | matsu_gan_mlp = [461.349531434,519.603921001,513.949371208] 111 | matsu_wgan_mlp = [360.227033106,392.627267256,361.221227421] 112 | 113 | N=4 114 | ind = np.arange(N) # the x locations for the groups 115 | width = 0.35 # the width of the bars 116 | 117 | fig, ax = plt.subplots() 118 | 119 | matsu_mean = [np.mean(matsu_gan_dcgan), np.mean(matsu_wgan_dcgan), np.mean(matsu_gan_mlp), np.mean(matsu_wgan_mlp)] 120 | matsu_std = [np.std(matsu_gan_dcgan), np.std(matsu_wgan_dcgan), np.std(matsu_gan_mlp), np.std(matsu_wgan_mlp)] 121 | rects_matsu = ax.bar(ind, matsu_mean, width, color='lightblue', yerr=matsu_std) 122 | 123 | relu_mean = [681.01717885, 542.51281229, 457.24837393, 413.81168954] 124 | relu_std = [25.93946989, 21.66889571, 35.52057633, 125.71019851] 125 | rects_relu = ax.bar(ind + width, relu_mean, width, color='lightgreen', yerr=relu_std) 126 | 127 | 128 | ax.set_ylabel('Log Probability') 129 | ax.set_xticks(ind + width) 130 | ax.set_xticklabels(('gan_dcgan', 'wgan_dcgan', 'gan_mlp', 'wgan_mlp')) 131 | 132 | ax.legend((rects_matsu[0], rects_relu[0]), ('Matsushita', 'Standard')) 133 | 134 | # plt.show() 135 | plt.savefig(join(target_dir, "discriminator_last_layer.png"), bbox_inches='tight') 136 | 137 | 138 | def autolabel(rects): 139 | """ 140 | Attach a text label above each bar displaying its height 141 | """ 142 | for rect in rects: 143 | height = rect.get_height() 144 | ax.text(rect.get_x() + rect.get_width()/2., 1.05*height, 145 | '%d' % int(height), 146 | ha='center', va='bottom') 147 | 148 | autolabel(rects_matsu) 149 | autolabel(rects_relu) 150 | 151 | def main(): 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('--result_dir', required=True, help='path to results') 154 | parser.add_argument('--target_dir', required=True, help='path to target folder') 155 | opt = parser.parse_args() 156 | plot_results(opt.result_dir, opt.target_dir) 157 | barplot(opt.target_dir) 158 | discriminator_barplot(opt.target_dir) 159 | 160 | if __name__ == '__main__': 161 | main() -------------------------------------------------------------------------------- /models/dcgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | from models.Matsushita import MatsushitaTransform,NormalizedMatsushita,MatsushitaLinkFunc 5 | from models.NormalizedSoftplus import NSoftPlus 6 | from models.LeastSquare import LeastSquare 7 | 8 | class DCGAN_D(nn.Module): 9 | def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0, last_layer=''): 10 | super(DCGAN_D, self).__init__() 11 | self.ngpu = ngpu 12 | assert isize % 16 == 0, "isize has to be a multiple of 16" 13 | 14 | main = nn.Sequential() 15 | # input is nc x isize x isize 16 | main.add_module('initial.conv.{0}-{1}'.format(nc, ndf), 17 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)) 18 | 
main.add_module('initial.relu.{0}'.format(ndf), 19 | nn.LeakyReLU(0.2, inplace=True)) 20 | csize, cndf = isize / 2, ndf 21 | 22 | # Extra layers 23 | for t in range(n_extra_layers): 24 | main.add_module('extra-layers-{0}.{1}.conv'.format(t, cndf), 25 | nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False)) 26 | main.add_module('extra-layers-{0}.{1}.batchnorm'.format(t, cndf), 27 | nn.BatchNorm2d(cndf)) 28 | main.add_module('extra-layers-{0}.{1}.relu'.format(t, cndf), 29 | nn.LeakyReLU(0.2, inplace=True)) 30 | 31 | while csize > 4: 32 | in_feat = cndf 33 | out_feat = cndf * 2 34 | main.add_module('pyramid.{0}-{1}.conv'.format(in_feat, out_feat), 35 | nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False)) 36 | main.add_module('pyramid.{0}.batchnorm'.format(out_feat), 37 | nn.BatchNorm2d(out_feat)) 38 | main.add_module('pyramid.{0}.relu'.format(out_feat), 39 | nn.LeakyReLU(0.2, inplace=True)) 40 | cndf = cndf * 2 41 | csize = csize / 2 42 | 43 | # state size. K x 4 x 4 44 | main.add_module('final.{0}-{1}.conv'.format(cndf, 1), 45 | nn.Conv2d(cndf, 1, 4, 1, 0, bias=False)) 46 | 47 | if last_layer == 'sigmoid': 48 | main.add_module('final.{0}.sigmoid'.format(nc), 49 | nn.Sigmoid()) 50 | elif last_layer == 'matsu': 51 | main.add_module('final.{0}.Matsushita'.format(nc), 52 | MatsushitaLinkFunc()) 53 | self.main = main 54 | 55 | 56 | def forward(self, input): 57 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 58 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 59 | else: 60 | output = self.main(input) 61 | 62 | return output.view(-1, 1) 63 | 64 | class DCGAN_G(nn.Module): 65 | def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0, hidden_activation = '', mu = 0, last_layer = 'tanh'): 66 | super(DCGAN_G, self).__init__() 67 | self.ngpu = ngpu 68 | assert isize % 16 == 0, "isize has to be a multiple of 16" 69 | 70 | cngf, tisize = ngf//2, 4 71 | while tisize != isize: 72 | cngf = cngf * 2 73 | tisize = tisize * 2 74 | 75 | main = nn.Sequential() 76 | # input is Z, going into a convolution 77 | main.add_module('initial.{0}-{1}.convt'.format(nz, cngf), 78 | nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False)) 79 | main.add_module('initial.{0}.batchnorm'.format(cngf), 80 | nn.BatchNorm2d(cngf)) 81 | if hidden_activation == 'murelu': 82 | main.add_module('initial.{0}.matsushita({1})'.format(cngf, mu), 83 | NormalizedMatsushita(mu)) 84 | elif hidden_activation == 'ls': 85 | main.add_module('initial.{0}.LeastSquare'.format(cngf // 2), 86 | LeastSquare()) 87 | elif hidden_activation == 'sp': 88 | main.add_module('initial.{0}.SoftPlus'.format(cngf // 2), 89 | NSoftPlus()) 90 | else: 91 | main.add_module('initial.{0}.relu'.format(cngf), 92 | nn.ReLU(True)) 93 | 94 | csize, cndf = 4, cngf 95 | while csize < isize//2: 96 | main.add_module('pyramid.{0}-{1}.convt'.format(cngf, cngf//2), 97 | nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False)) 98 | main.add_module('pyramid.{0}.batchnorm'.format(cngf//2), 99 | nn.BatchNorm2d(cngf//2)) 100 | if hidden_activation == 'murelu': 101 | main.add_module('pyramid.{0}.matsushita({1})'.format(cngf//2, mu), 102 | NormalizedMatsushita(mu)) 103 | elif hidden_activation == 'ls': 104 | main.add_module('pyramid.{0}.LeastSquare'.format(cngf // 2), 105 | LeastSquare()) 106 | elif hidden_activation == 'sp': 107 | main.add_module('pyramid.{0}.SoftPlus'.format(cngf // 2), 108 | NSoftPlus()) 109 | else: 110 | main.add_module('pyramid.{0}.relu'.format(cngf//2), 111 | nn.ReLU(True)) 112 | cngf = cngf // 2 113 | csize = csize * 2 114 | 115 | # 
Extra layers 116 | for t in range(n_extra_layers): 117 | main.add_module('extra-layers-{0}.{1}.conv'.format(t, cngf), 118 | nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False)) 119 | main.add_module('extra-layers-{0}.{1}.batchnorm'.format(t, cngf), 120 | nn.BatchNorm2d(cngf)) 121 | if hidden_activation == 'murelu': 122 | main.add_module('extra-layers-{0}.{1}.matsushita({2})'.format(t, cngf, mu), 123 | NormalizedMatsushita(mu)) 124 | elif hidden_activation == 'ls': 125 | main.add_module('extra-layers-{0}.{1}.LeastSquare'.format(t, cngf), LeastSquare()) 126 | elif hidden_activation == 'sp': 127 | main.add_module('extra-layers-{0}.{1}.Softplus'.format(t, cngf), NSoftPlus()) 128 | else: 129 | main.add_module('extra-layers-{0}.{1}.relu'.format(t, cngf), 130 | nn.ReLU(True)) 131 | 132 | main.add_module('final.{0}-{1}.convt'.format(cngf, nc), 133 | nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False)) 134 | 135 | if last_layer == 'sigmoid': 136 | main.add_module('final.{0}.sigmoid'.format(nc), 137 | nn.Sigmoid()) 138 | else: 139 | main.add_module('final.{0}.tanh'.format(nc), 140 | nn.Tanh()) 141 | 142 | self.main = main 143 | 144 | def forward(self, input): 145 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 146 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 147 | else: 148 | output = self.main(input) 149 | return output 150 | ############################################################################### 151 | class DCGAN_D_nobn(nn.Module): 152 | def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0): 153 | super(DCGAN_D_nobn, self).__init__() 154 | self.ngpu = ngpu 155 | assert isize % 16 == 0, "isize has to be a multiple of 16" 156 | 157 | main = nn.Sequential() 158 | # input is nc x isize x isize 159 | # input is nc x isize x isize 160 | main.add_module('initial.conv.{0}-{1}'.format(nc, ndf), 161 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)) 162 | main.add_module('initial.relu.{0}'.format(ndf), 163 | nn.LeakyReLU(0.2, inplace=True)) 164 | csize, cndf = isize / 2, ndf 165 | 166 | # Extra layers 167 | for t in range(n_extra_layers): 168 | main.add_module('extra-layers-{0}.{1}.conv'.format(t, cndf), 169 | nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False)) 170 | main.add_module('extra-layers-{0}.{1}.relu'.format(t, cndf), 171 | nn.LeakyReLU(0.2, inplace=True)) 172 | 173 | while csize > 4: 174 | in_feat = cndf 175 | out_feat = cndf * 2 176 | main.add_module('pyramid.{0}-{1}.conv'.format(in_feat, out_feat), 177 | nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False)) 178 | main.add_module('pyramid.{0}.relu'.format(out_feat), 179 | nn.LeakyReLU(0.2, inplace=True)) 180 | cndf = cndf * 2 181 | csize = csize / 2 182 | 183 | # state size. 
K x 4 x 4 184 | main.add_module('final.{0}-{1}.conv'.format(cndf, 1), 185 | nn.Conv2d(cndf, 1, 4, 1, 0, bias=False)) 186 | self.main = main 187 | 188 | 189 | def forward(self, input): 190 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 191 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 192 | else: 193 | output = self.main(input) 194 | 195 | output = output.mean(0) 196 | return output.view(1) 197 | 198 | 199 | class DCGAN_G_nobn(nn.Module): 200 | def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0, matsushita_layer = ''): 201 | super(DCGAN_G_nobn, self).__init__() 202 | self.ngpu = ngpu 203 | assert isize % 16 == 0, "isize has to be a multiple of 16" 204 | 205 | cngf, tisize = ngf//2, 4 206 | while tisize != isize: 207 | cngf = cngf * 2 208 | tisize = tisize * 2 209 | 210 | main = nn.Sequential() 211 | main.add_module('initial.{0}-{1}.convt'.format(nz, cngf), 212 | nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False)) 213 | main.add_module('initial.{0}.relu'.format(cngf), 214 | nn.ReLU(True)) 215 | 216 | csize, cndf = 4, cngf 217 | while csize < isize//2: 218 | main.add_module('pyramid.{0}-{1}.convt'.format(cngf, cngf//2), 219 | nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False)) 220 | main.add_module('pyramid.{0}.relu'.format(cngf//2), 221 | nn.ReLU(True)) 222 | cngf = cngf // 2 223 | csize = csize * 2 224 | 225 | # Extra layers 226 | for t in range(n_extra_layers): 227 | main.add_module('extra-layers-{0}.{1}.conv'.format(t, cngf), 228 | nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False)) 229 | main.add_module('extra-layers-{0}.{1}.relu'.format(t, cngf), 230 | nn.ReLU(True)) 231 | 232 | main.add_module('final.{0}-{1}.convt'.format(cngf, nc), 233 | nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False)) 234 | 235 | if matsushita_layer == '2nd_last': 236 | main.add_module('2nd_last.{0}.Matsushita'.format(nc), 237 | MatsushitaTransform()) 238 | main.add_module('final.{0}.tanh'.format(nc), 239 | nn.Tanh()) 240 | elif matsushita_layer == 'last': 241 | main.add_module('final.{0}.Matsushita'.format(nc), 242 | MatsushitaTransform()) 243 | else: 244 | main.add_module('final.{0}.tanh'.format(nc), 245 | nn.Tanh()) 246 | 247 | 248 | self.main = main 249 | 250 | def forward(self, input): 251 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 252 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 253 | else: 254 | output = self.main(input) 255 | return output 256 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | import torchvision.datasets as dset 11 | import torchvision.transforms as transforms 12 | import torchvision.utils as vutils 13 | from torch.autograd import Variable 14 | import os 15 | 16 | import models.dcgan as dcgan 17 | import models.mlp as mlp 18 | from models import gan 19 | from fileUtil import FileUtil 20 | from kde import cal_logprob 21 | from kde import fit_kde 22 | from kde import convert_to_ndarrays 23 | import numpy as np 24 | import math 25 | import time 26 | import logging 27 | import csv 28 | import traceback 29 | 30 | 31 | def train(opt, log_file_path): 32 | if opt.experiment is None: 33 | opt.experiment = 'samples' 34 | os.system('mkdir {0}'.format(opt.experiment)) 35 | elif not os.path.exists(opt.experiment): 36 | os.system('mkdir {0}'.format(opt.experiment)) 37 | 38 | logger = logging.getLogger() 39 | for hdlr in logger.handlers[:]: # remove all old handlers 40 | logger.removeHandler(hdlr) 41 | logger.setLevel(logging.INFO) 42 | 43 | # create a file handler 44 | handler = logging.FileHandler(log_file_path) 45 | handler.setLevel(logging.INFO) 46 | 47 | # create a logging format 48 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 49 | handler.setFormatter(formatter) 50 | 51 | # add the handlers to the logger 52 | logger.addHandler(handler) 53 | 54 | stream_handler = logging.StreamHandler() 55 | stream_handler.setLevel(logging.INFO) 56 | stream_handler.setFormatter(formatter) 57 | logger.addHandler(stream_handler) 58 | 59 | logger.info(opt) 60 | 61 | #opt.manualSeed = random.randint(1, 10000) # fix seed 62 | logger.info("Random Seed: %s " % opt.manualSeed) 63 | random.seed(opt.manualSeed) 64 | torch.manual_seed(opt.manualSeed) 65 | if opt.cuda: 66 | torch.cuda.manual_seed_all(opt.manualSeed) 67 | 68 | cudnn.benchmark = True 69 | 70 | nc = int(opt.nc) 71 | 72 | if torch.cuda.is_available() and not opt.cuda: 73 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 74 | 75 | def eval_with_KDE(generator, original_test_set, random_test_noise): 76 | num_insts = random_test_noise.size()[0] 77 | num_batches = int(math.ceil(num_insts / float(opt.batchSize))) 78 | instances = [] 79 | for i in range(num_batches): 80 | fake_test_set = generator( 81 | Variable(random_test_noise[i * opt.batchSize: (i + 1) * opt.batchSize], volatile=True)) 82 | instances.extend([vec.flatten() for vec in fake_test_set.data.cpu().numpy()]) 83 | flattened_data = np.stack(instances) 84 | # print('input data for KDE is of shape {0} '.format(flattened_data.shape)) 85 | kde = fit_kde(flattened_data, bandwidth=opt.bandwidth) 86 | mean_logp = cal_logprob(kde, original_test_set) 87 | 
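        # Parzen-window evaluation: fit a Gaussian KDE to samples drawn from
        # the generator, then report the mean log-density that this KDE
        # assigns to held-out real images (higher is better). opt.bandwidth
        # is set per dataset further below; `python kde.py --task hyper`
        # runs the grid search used to pick these values.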
return mean_logp 88 | 89 | def xavier_init(param): 90 | size = param.data.size() 91 | in_dim = size[0] 92 | xavier_stddev = 1. / np.sqrt(in_dim / 2.) 93 | param.data = torch.randn(*size) * xavier_stddev 94 | 95 | sample_validation_without_replacement = True 96 | val_set = [] 97 | if opt.dataset == 'lsun': 98 | opt.normalizeImages = True 99 | #3x256x341 100 | if opt.normalizeImages: 101 | transform_op = transforms.Compose([ 102 | transforms.Scale(opt.imageSize), 103 | transforms.CenterCrop(opt.imageSize), 104 | transforms.ToTensor(), 105 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 106 | ]) 107 | else: 108 | transform_op = transforms.Compose([ 109 | transforms.Scale(opt.imageSize), 110 | transforms.CenterCrop(opt.imageSize), 111 | transforms.ToTensor(), 112 | ]) 113 | dataset = dset.LSUN(db_path=opt.dataroot, classes=['{0}_train'.format(opt.subset)], 114 | transform=transform_op) 115 | 116 | if opt.task == 'hyper': 117 | dataset = [dataset[i] for i in range(10000, 40000)] 118 | 119 | test_dataset = dset.LSUN(db_path=opt.dataroot, classes=['{0}_val'.format(opt.subset)], 120 | transform=transform_op) 121 | val_set = convert_to_ndarrays([dataset[i] for i in range(0, 1000)]) 122 | if sample_validation_without_replacement: 123 | dataset = [dataset[i] for i in range(1001, len(dataset))] 124 | nc = 3 125 | opt.bandwidth = 0.335981828628 126 | size_test_noise = 3000 127 | size_val_noise = 3000 128 | elif opt.dataset == 'cifar10': 129 | opt.normalizeImages = True 130 | dataset = dset.CIFAR10(root=opt.dataroot, download=True, 131 | transform=transforms.Compose([ 132 | transforms.Scale(opt.imageSize), 133 | transforms.ToTensor(), 134 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 135 | ])) 136 | 137 | test_dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=False, 138 | transform=transforms.Compose([ 139 | transforms.Scale(opt.imageSize), 140 | transforms.ToTensor(), 141 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 142 | ])) 143 | print('number of images in test set %s ' % len(test_dataset)) 144 | val_set = convert_to_ndarrays([test_dataset[i] for i in range(0, 1000)]) 145 | test_dataset = [test_dataset[i] for i in range(1001, len(test_dataset))] 146 | if sample_validation_without_replacement: 147 | dataset = [dataset[i] for i in range(1001, len(dataset))] 148 | nc = 3 149 | size_test_noise = 6000 150 | size_val_noise = 6000 151 | if opt.imageSize == 32: 152 | opt.bandwidth = 0.263665089873 153 | else: 154 | opt.bandwidth = 0.335981828628 155 | elif opt.dataset == 'mnist': 156 | 157 | if opt.normalizeImages: 158 | img_transform = transforms.Compose([ 159 | transforms.Scale(opt.imageSize), 160 | transforms.ToTensor(), 161 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 162 | ]) 163 | else: 164 | img_transform = transforms.Compose([ 165 | transforms.Scale(opt.imageSize), 166 | transforms.ToTensor(), 167 | ]) 168 | 169 | dataset = dset.MNIST(root=opt.dataroot, download=True, 170 | transform=img_transform) 171 | test_dataset = dset.MNIST(root=opt.dataroot, download=True, train=False, 172 | transform=img_transform) 173 | nc = 1 174 | 175 | if opt.imageSize == 32: 176 | opt.bandwidth = 0.1 177 | else: 178 | opt.bandwidth = 0.12742749857 179 | val_set = convert_to_ndarrays([dataset[i] for i in range(50000, 51000)]) 180 | if sample_validation_without_replacement: 181 | dataset = [dataset[i] for i in range(0, len(dataset)) if i < 50000 or i > 51000] 182 | size_test_noise = 16000 183 | size_val_noise = 5000 184 | 185 | # if opt.A in ['mlp']: 186 | # 
 94 | 
 95 |     sample_validation_without_replacement = True
 96 |     val_set = []
 97 |     if opt.dataset == 'lsun':
 98 |         opt.normalizeImages = True
 99 |         # 3x256x341
100 |         if opt.normalizeImages:
101 |             transform_op = transforms.Compose([
102 |                 transforms.Scale(opt.imageSize),
103 |                 transforms.CenterCrop(opt.imageSize),
104 |                 transforms.ToTensor(),
105 |                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
106 |             ])
107 |         else:
108 |             transform_op = transforms.Compose([
109 |                 transforms.Scale(opt.imageSize),
110 |                 transforms.CenterCrop(opt.imageSize),
111 |                 transforms.ToTensor(),
112 |             ])
113 |         dataset = dset.LSUN(db_path=opt.dataroot, classes=['{0}_train'.format(opt.subset)],
114 |                             transform=transform_op)
115 | 
116 |         if opt.task == 'hyper':
117 |             dataset = [dataset[i] for i in range(10000, 40000)]
118 | 
119 |         test_dataset = dset.LSUN(db_path=opt.dataroot, classes=['{0}_val'.format(opt.subset)],
120 |                                  transform=transform_op)
121 |         val_set = convert_to_ndarrays([dataset[i] for i in range(0, 1000)])
122 |         if sample_validation_without_replacement:
123 |             dataset = [dataset[i] for i in range(1000, len(dataset))]
124 |         nc = 3
125 |         opt.bandwidth = 0.335981828628
126 |         size_test_noise = 3000
127 |         size_val_noise = 3000
128 |     elif opt.dataset == 'cifar10':
129 |         opt.normalizeImages = True
130 |         dataset = dset.CIFAR10(root=opt.dataroot, download=True,
131 |                                transform=transforms.Compose([
132 |                                    transforms.Scale(opt.imageSize),
133 |                                    transforms.ToTensor(),
134 |                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
135 |                                ]))
136 | 
137 |         test_dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=False,
138 |                                     transform=transforms.Compose([
139 |                                         transforms.Scale(opt.imageSize),
140 |                                         transforms.ToTensor(),
141 |                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
142 |                                     ]))
143 |         print('number of images in test set %s ' % len(test_dataset))
144 |         val_set = convert_to_ndarrays([test_dataset[i] for i in range(0, 1000)])
145 |         test_dataset = [test_dataset[i] for i in range(1000, len(test_dataset))]
146 |         if sample_validation_without_replacement:
147 |             dataset = [dataset[i] for i in range(1001, len(dataset))]
148 |         nc = 3
149 |         size_test_noise = 6000
150 |         size_val_noise = 6000
151 |         if opt.imageSize == 32:
152 |             opt.bandwidth = 0.263665089873
153 |         else:
154 |             opt.bandwidth = 0.335981828628
155 |     elif opt.dataset == 'mnist':
156 | 
157 |         if opt.normalizeImages:
158 |             img_transform = transforms.Compose([
159 |                 transforms.Scale(opt.imageSize),
160 |                 transforms.ToTensor(),
161 |                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
162 |             ])
163 |         else:
164 |             img_transform = transforms.Compose([
165 |                 transforms.Scale(opt.imageSize),
166 |                 transforms.ToTensor(),
167 |             ])
168 | 
169 |         dataset = dset.MNIST(root=opt.dataroot, download=True,
170 |                              transform=img_transform)
171 |         test_dataset = dset.MNIST(root=opt.dataroot, download=True, train=False,
172 |                                   transform=img_transform)
173 |         nc = 1
174 | 
175 |         if opt.imageSize == 32:
176 |             opt.bandwidth = 0.1
177 |         else:
178 |             opt.bandwidth = 0.12742749857
179 |         val_set = convert_to_ndarrays([dataset[i] for i in range(50000, 51000)])
180 |         if sample_validation_without_replacement:
181 |             dataset = [dataset[i] for i in range(0, len(dataset)) if i < 50000 or i >= 51000]
182 |         size_test_noise = 16000
183 |         size_val_noise = 5000
184 | 
185 |     # if opt.A in ['mlp']:
186 |     #     logger.info('Apply experimental setting of F-GAN on MNIST.')
187 |     #     opt.nz = 100
188 |     #     opt.ndf = 240
189 |     #     opt.ngf = 1200
190 |     #     opt.init_z = 'uniform_one'
191 |     #     opt.adam = True
192 |     #     opt.lrD = 0.0002
193 |     #     opt.lrG = 0.0002
194 |     #     opt.beta1 = 0.5
195 |     #     opt.batchSize = 4096
196 |     #     opt.init_w = 'uniform'
197 |     #     opt.last_layer = 'sigmoid'
198 |     #     opt.clamp_lower = 0
199 |     #     opt.clamp_upper = 0
200 | 
201 |     assert dataset
202 |     assert test_dataset
203 |     assert len(val_set) > 0
204 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
205 |                                              shuffle=True, num_workers=int(opt.workers))
206 | 
207 |     ngpu = 1
208 |     if opt.gpu_id < 0:
209 |         ngpu = int(opt.ngpu)
210 |     nz = int(opt.nz)
211 |     ngf = int(opt.ngf)
212 |     ndf = int(opt.ndf)
213 |     n_extra_layers = int(opt.n_extra_layers)
214 | 
215 |     # Load KDE model, if available
216 | 
217 |     def init_z(tensor):
218 |         if opt.init_z == 'uniform_one':
219 |             tensor.uniform_(-1, 1)
220 |         elif opt.init_z == 'uniform_zero_one':
221 |             tensor.uniform_(0, 1)
222 |         else:
223 |             tensor.normal_(0, 1)
224 |         return tensor
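    # Usage sketch: init_z fills a tensor in place according to --init_z
    # ('uniform_one' -> U(-1, 1), 'uniform_zero_one' -> U(0, 1), anything
    # else -> N(0, 1)) and also returns it, so a fresh latent batch can be
    # drawn inline, e.g.:
    #
    #   z = init_z(torch.FloatTensor(opt.batchSize, nz, 1, 1))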
225 | 
226 |     test_noise = None
227 |     val_noise = None
228 | 
229 | 
230 |     if opt.bandwidth != 0:
231 |         test_noise = init_z(torch.FloatTensor(size_test_noise, nz, 1, 1))
232 |         val_noise = init_z(torch.FloatTensor(size_val_noise, nz, 1, 1))
233 |         test_set = convert_to_ndarrays(test_dataset)
234 |         if opt.cuda:
235 |             test_noise = test_noise.cuda()
236 |             val_noise = val_noise.cuda()
237 | 
238 |     # custom weights initialization called on netG and netD
239 |     def weights_init(m):
240 |         classname = m.__class__.__name__
241 |         if classname.find('Conv') != -1:
242 |             m.weight.data.normal_(0.0, 0.02)
243 |         elif classname.find('BatchNorm') != -1:
244 |             m.weight.data.normal_(1.0, 0.02)
245 |             m.bias.data.fill_(0)
246 | 
247 |     # initialize generator
248 | 
249 |     if opt.A == 'wmlp':
250 |         netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu, hidden_activation=opt.H, mu=opt.mu, last_layer=opt.last_layer)
251 |         if opt.init_w == 'xavier':
252 |             [xavier_init(param) for param in netG.parameters()]
253 |         else:
254 |             [param.data.uniform_(-0.05, 0.05) for param in netG.parameters()]
255 |     elif opt.A == 'mlp':
256 |         netG = gan.GAN_G(opt.imageSize, nz, nc, ngf, ngpu, hidden_activation=opt.H, mu=opt.mu, last_layer=opt.last_layer)
257 |         if opt.init_w == 'xavier':
258 |             [xavier_init(param) for param in netG.parameters()]
259 |         else:
260 |             [param.data.uniform_(-0.05, 0.05) for param in netG.parameters()]
261 |     else:
262 |         if opt.noBN:
263 |             netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
264 |         else:
265 |             last_layer = 'sigmoid'
266 |             if opt.normalizeImages:
267 |                 last_layer = 'tanh'
268 |             netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers, hidden_activation=opt.H, mu=opt.mu, last_layer=last_layer)
269 |         netG.apply(weights_init)
270 | 
271 |     if opt.netG != '':  # load checkpoint if needed
272 |         netG.load_state_dict(torch.load(opt.netG))
273 |     logger.info(netG)
274 | 
275 |     # Initialize critic
276 |     if opt.A == 'wmlp':
277 |         netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
278 |         if opt.init_w == 'xavier':
279 |             [xavier_init(param) for param in netD.parameters()]
280 |         else:
281 |             [param.data.uniform_(-0.005, 0.005) for param in netD.parameters()]
282 |     elif opt.A == 'mlp':
283 |         netD = gan.GAN_D(opt.imageSize, nz, nc, ndf, ngpu, hidden_activation=opt.c_activation, last_layer=opt.critic_last_layer, alpha=opt.alpha)
284 |         if opt.init_w == 'xavier':
285 |             [xavier_init(param) for param in netD.parameters()]
286 |         else:
287 |             [param.data.uniform_(-0.005, 0.005) for param in netD.parameters()]
288 |     else:
289 |         netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers, last_layer=opt.critic_last_layer)
290 |         netD.apply(weights_init)
291 | 
292 |     if opt.netD != '':
293 |         netD.load_state_dict(torch.load(opt.netD))
294 |     logger.info(netD)
295 | 
296 |     input = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize)
297 |     noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
298 |     num_images = 24
299 |     fixed_noise = init_z(torch.FloatTensor(num_images, nz, 1, 1))
300 |     one = torch.FloatTensor([1])
301 |     mone = one * -1
302 | 
303 |     # for GAN
304 |     label = torch.FloatTensor(opt.batchSize)
305 |     real_label = 1
306 |     fake_label = 0
307 |     criterion = nn.BCELoss()
308 | 
309 |     if opt.cuda:
310 |         netD.cuda()
311 |         netG.cuda()
312 |         input = input.cuda()
313 |         criterion.cuda()
314 |         label = label.cuda()
315 |         one, mone = one.cuda(), mone.cuda()
316 |         noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
317 | 
318 |     # setup optimizer
319 |     if opt.adam:
320 |         logger.info("Use ADAM")
321 |         optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999), weight_decay=opt.weightDecay)
322 |         optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999), weight_decay=opt.weightDecay)
323 |     else:
324 |         logger.info("Use RMSprop")
325 |         optimizerD = optim.RMSprop(netD.parameters(), lr=opt.lrD)
326 |         optimizerG = optim.RMSprop(netG.parameters(), lr=opt.lrG)
327 | 
328 |     def sample_image_compute_density(start_epoch, end_epoch):
329 | 
330 |         with open(os.path.join(opt.kde_result_dir, 'kde_results.csv'), 'w') as kde_file:
331 |             best_logprob = -10000000
332 |             for epoch in range(start_epoch, end_epoch):
333 |                 netG_model = '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch)
334 |                 if os.path.exists(netG_model):
335 |                     netG.load_state_dict(torch.load(netG_model))
336 |                     if opt.cuda:
337 |                         netG.cuda()
338 |                     fake = netG(Variable(fixed_noise, volatile=True))
339 |                     if opt.normalizeImages:
340 |                         fake.data = fake.data.mul(0.5).add(0.5)
341 |                     vutils.save_image(fake.data,
342 |                                       '{0}/fake_samples_epoch_{1}.png'.format(opt.kde_result_dir, epoch))
343 |                     logprob_mean = eval_with_KDE(netG, test_set, test_noise)
344 |                     kde_file.write("{0}\t{1}\n".format(epoch, logprob_mean))
345 |                     best_logprob = max(best_logprob, logprob_mean)
346 |             return best_logprob
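    # A minimal sketch, assuming torch.nn.functional.softplus is available in
    # this PyTorch version; the helper is defined here but not wired into the
    # training loops. The fgan branch of eval() below evaluates
    # log(1 + exp(-v)) and log(1 - sigmoid(v)) directly, which can overflow
    # for large |v|; the same terms rewrite stably via
    # -log(sigmoid(v)) = softplus(-v) and -log(1 - sigmoid(v)) = softplus(v).
    def stable_fgan_losses(d_real, d_fake):
        import torch.nn.functional as F
        loss_d = torch.mean(F.softplus(d_fake)) + torch.mean(F.softplus(torch.neg(d_real)))
        loss_g = torch.mean(F.softplus(torch.neg(d_fake)))
        return loss_d, loss_g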
347 | 
348 |     best_logprob = 0
349 |     label = Variable(label)
350 |     noisev = Variable(noise)
351 |     def eval_gan(netDiscriminator, netGenerator):
352 |         gen_iterations = 0
353 |         train_best_netG_model = ''
354 |         train_best_logprob = 0
355 |         num_epochs_img = max(3, opt.niter / 5)
356 |         if opt.dataset == 'lsun':
357 |             num_epochs_img = 1
358 |         start_time = time.time()
359 | 
360 |         for epoch in range(opt.niter):
361 |             loss_D = 0
362 |             loss_G = 0
363 |             for i, data in enumerate(dataloader, 0):
364 |                 ############################
365 |                 # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
366 |                 ###########################
367 |                 # train with real
368 |                 for p in netDiscriminator.parameters():  # reset requires_grad
369 |                     p.requires_grad = True  # they are set to False below in netG update
370 |                 for p in netGenerator.parameters():  # disable grad of generator
371 |                     p.requires_grad = False
372 | 
373 |                 if opt.clamp_upper > 0:
374 |                     # clamp parameters to a cube
375 |                     for p in netDiscriminator.parameters():
376 |                         p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
377 | 
378 |                 netDiscriminator.zero_grad()
379 |                 real_cpu, _ = data
380 |                 batch_size = real_cpu.size(0)
381 |                 if opt.cuda:
382 |                     real_cpu = real_cpu.cuda()
383 |                 input.resize_as_(real_cpu).copy_(real_cpu)
384 |                 inputv = Variable(input)
385 | 
386 | 
387 |                 label.data.resize_(batch_size).fill_(real_label)
388 | 
389 |                 output = netDiscriminator(inputv)
390 |                 #errD_real = criterion(output, label)
391 |                 errD_real = torch.mean(torch.neg(torch.log(output)))
392 |                 errD_real.backward()
393 | 
394 |                 # train with fake
395 |                 noisev.data.resize_(batch_size, nz, 1, 1)
396 |                 init_z(noisev.data)  # sample a fresh noise batch
397 |                 fake = netGenerator(noisev)
398 |                 label.data.fill_(fake_label)
399 |                 output = netDiscriminator(fake.detach())
400 |                 errD_fake = criterion(output, label)
401 |                 #errD_fake = torch.mean(torch.neg(torch.log(1 - torch.exp(torch.log(output)))))
402 |                 errD_fake.backward()
403 |                 errD = errD_real + errD_fake
404 |                 loss_D += errD.data.sum()
405 |                 optimizerD.step()
406 | 
407 |                 ############################
408 |                 # (2) Update G network: maximize log(D(G(z)))
409 |                 ###########################
410 |                 for p in netDiscriminator.parameters():
411 |                     p.requires_grad = False  # to avoid computation
412 |                 for p in netGenerator.parameters():  # reset requires_grad
413 |                     p.requires_grad = True
414 |                 netGenerator.zero_grad()
415 |                 label.data.fill_(real_label)  # fake labels are real for generator cost
416 |                 init_z(noisev.data)  # sample a fresh noise batch
417 |                 fake = netGenerator(noisev)
418 |                 fake_output = netDiscriminator(fake)
419 |                 errG = criterion(fake_output, label)
420 |                 #errG = torch.mean(torch.neg(torch.log(fake_output)))
421 |                 errG.backward()
422 |                 loss_G += errG.data.sum()
423 |                 optimizerG.step()
424 |                 gen_iterations += 1
425 | 
426 | 
427 |             if epoch % num_epochs_img == 0:
428 |                 if opt.normalizeImages:
429 |                     real_cpu = real_cpu.mul(0.5).add(0.5)
430 | 
431 |                 vutils.save_image(real_cpu[0:num_images], '{0}/real_samples.png'.format(opt.experiment))
432 |                 fake = netGenerator(Variable(fixed_noise, volatile=True))
433 |                 if opt.normalizeImages:
434 |                     fake.data = fake.data.mul(0.5).add(0.5)
435 |                 vutils.save_image(fake.data,
436 |                                   '{0}/fake_samples_{1}_epoch_{2}.png'.format(opt.experiment, gen_iterations, epoch))
437 | 
438 |             # do checkpointing
439 |             end_time = time.time()
440 |             logger.info('[%d/%d] : Loss_D: %f Loss_G: %f, running time: %f'
441 |                         % (epoch, opt.niter,
442 |                            loss_D, loss_G, (end_time - start_time)))
443 |             torch.save(netGenerator.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
444 |             torch.save(netDiscriminator.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch))
445 |             val_logprob_normalized = eval_with_KDE(netGenerator, val_set, val_noise)
446 |             if train_best_netG_model == '' or val_logprob_normalized > train_best_logprob:
447 |                 train_best_logprob = val_logprob_normalized
448 |                 train_best_netG_model = '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch)
449 |                 logger.info(
450 |                     'The current best model is epoch {0} with mean log probability {1}.'.format(
451 |                         epoch, val_logprob_normalized))
452 | 
453 |         return train_best_netG_model, train_best_logprob
454 | 
455 |     def eval(netDiscriminator, netGenerator):
456 |         gen_iterations = 0
457 |         train_best_netG_model = ''
458 |         train_best_logprob = 0
459 |         num_epochs_img = max(3, opt.niter / 5)
460 |         if opt.dataset == 'lsun':
461 |             num_epochs_img = 1
462 |         start_time = time.time()
463 | 
464 |         for epoch in range(opt.niter):
465 |             data_iter = iter(dataloader)
466 |             i = 0
467 |             loss_D = 0
468 |             loss_G = 0
469 |             while i < len(dataloader):
470 |                 ############################
471 |                 # (1) Update D 
network 472 | ########################### 473 | for p in netDiscriminator.parameters(): # reset requires_grad 474 | p.requires_grad = True # they are set to False below in netG update 475 | for p in netGenerator.parameters(): # disable grad of generator 476 | p.requires_grad = False 477 | 478 | Diters = 1 479 | if opt.D == 'wgan': 480 | # train the discriminator Diters times 481 | if opt.wganheuristics and (gen_iterations < 25 or gen_iterations % 500 == 0): 482 | Diters = 100 483 | else: 484 | Diters = opt.Diters 485 | j = 0 486 | while j < Diters and i < len(dataloader): 487 | j += 1 488 | 489 | if opt.clamp_upper > 0: 490 | # clamp parameters to a cube 491 | for p in netDiscriminator.parameters(): 492 | p.data.clamp_(opt.clamp_lower, opt.clamp_upper) 493 | 494 | data = data_iter.next() 495 | i += 1 496 | 497 | # train with real 498 | real_cpu, _ = data 499 | batch_size = real_cpu.size(0) 500 | netDiscriminator.zero_grad() 501 | 502 | if opt.cuda: 503 | real_cpu = real_cpu.cuda() 504 | input.resize_as_(real_cpu).copy_(real_cpu) 505 | inputv = Variable(input) 506 | 507 | errD_real = netDiscriminator(inputv) 508 | if opt.D == 'fgan': 509 | errD_real = torch.neg(torch.log(1 + torch.exp(torch.neg(errD_real)))) 510 | 511 | # train with fake 512 | init_z(noise.resize_(batch_size, nz, 1, 1)) 513 | noisev = Variable(noise) # totally freeze netG 514 | fake = Variable(netGenerator(noisev).data) 515 | errD_fake = netDiscriminator(fake) 516 | if opt.D == 'fgan': 517 | errD_fake = torch.log(1/(1 + torch.exp(torch.neg(errD_fake)))) 518 | errD_fake = torch.neg(torch.log(1 - torch.exp(errD_fake))) 519 | elif opt.D == 'kl': 520 | errD_fake = torch.exp(errD_fake - 1) 521 | loss_discriminator = torch.mean(errD_fake - errD_real) 522 | loss_discriminator.backward(one) 523 | optimizerD.step() 524 | #loss_D += errD_fake.data + errD_real.data 525 | loss_D += loss_discriminator.data 526 | 527 | 528 | ############################ 529 | # (2) Update G network 530 | ########################### 531 | for p in netDiscriminator.parameters(): 532 | p.requires_grad = False # to avoid computation 533 | for p in netGenerator.parameters(): # reset requires_grad 534 | p.requires_grad = True 535 | netGenerator.zero_grad() 536 | # in case our last batch was the tail batch of the dataloader, 537 | # make sure we feed a full batch of noise 538 | init_z(noise.resize_(batch_size, nz, 1, 1)) 539 | noisev = Variable(noise, volatile=False) 540 | fake = netGenerator(noisev) 541 | errG = netDiscriminator(fake) 542 | 543 | if opt.D == 'fgan': 544 | errG = torch.neg(torch.log(1 + torch.exp(torch.neg(errG)))) # sigmoid 545 | 546 | errG = torch.neg(torch.mean(errG)) 547 | errG.backward(one) 548 | optimizerG.step() 549 | loss_G += errG.data 550 | gen_iterations += 1 551 | 552 | # print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f' 553 | # % (epoch, opt.niter, i, len(dataloader), gen_iterations, 554 | # errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0])) 555 | if epoch % num_epochs_img == 0: 556 | if opt.normalizeImages: 557 | real_cpu = real_cpu.mul(0.5).add(0.5) 558 | 559 | vutils.save_image(real_cpu[0:num_images], '{0}/real_samples.png'.format(opt.experiment)) 560 | fake = netGenerator(Variable(fixed_noise, volatile=True)) 561 | if opt.normalizeImages: 562 | fake.data = fake.data.mul(0.5).add(0.5) 563 | vutils.save_image(fake.data, 564 | '{0}/fake_samples_{1}_epoch_{2}.png'.format(opt.experiment, gen_iterations, epoch)) 565 | 566 | # do checkpointing 567 | end_time = time.time() 568 | 
            logger.info('[%d/%d] : Loss_D: %f Loss_G: %f, running time: %f'
569 |                         % (epoch, opt.niter,
570 |                            loss_D[0], loss_G[0], (end_time - start_time)))
571 |             torch.save(netGenerator.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
572 |             torch.save(netDiscriminator.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch))
573 |             val_logprob_normalized = eval_with_KDE(netGenerator, val_set, val_noise)
574 |             if train_best_netG_model == '' or val_logprob_normalized > train_best_logprob:
575 |                 train_best_logprob = val_logprob_normalized
576 |                 train_best_netG_model = '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch)
577 |                 logger.info(
578 |                     'The current best model is epoch {0} with mean log probability {1}.'.format(
579 |                         epoch, val_logprob_normalized))
580 |         return train_best_netG_model, train_best_logprob
581 | 
582 |     if opt.task == 'eval_kde':
583 |         return sample_image_compute_density(opt.start, opt.end)
584 |     else:
585 |         if opt.kdeEpoch == 0:
586 |             if opt.D == 'gan':
587 |                 best_netG_model, best_logprob = eval_gan(netD, netG)
588 |             else:
589 |                 best_netG_model, best_logprob = eval(netD, netG)
590 |         else:
591 |             best_netG_model = '{0}/netG_epoch_{1}.pth'.format(opt.experiment, opt.kdeEpoch)
592 |             print('Load model from epoch %d for KDE evaluation.' % opt.kdeEpoch)
593 |         logger.info('Load the best model from {0} with log probability {1}.'.format(best_netG_model, best_logprob))
594 |         netG.load_state_dict(torch.load(best_netG_model))
595 |         if opt.cuda:
596 |             netG.cuda()
597 |         fake = netG(Variable(fixed_noise, volatile=True))
598 |         if opt.normalizeImages:
599 |             fake.data = fake.data.mul(0.5).add(0.5)
600 |         vutils.save_image(fake.data,
601 |                           '{0}/best_fake_samples.png'.format(opt.experiment))
602 |         logprob_mean = eval_with_KDE(netG, test_set, test_noise)
603 |         logger.info("On the test set, mean log probability is %s " % logprob_mean)
604 |         return logprob_mean
605 | 
606 | 
607 | 
608 | def search_hyperparams(opt):
609 |     learn_rates = [0.0002, 0.0001, 0.00005, 0.00001]
610 |     hidden_units = [opt.imageSize, opt.imageSize * 2, opt.imageSize * 8, opt.imageSize * 32, 1200]
611 |     clamping_bounds = [0, 0.1, 0.01, 0.001]
612 |     batch_size = [64, 4096]
613 |     init_weights = ['uniform', 'xavier']
614 |     init_noise = ['uniform_one', 'uniform_zero_one', 'gaussian']
615 | 
616 |     original_exp = opt.experiment
617 |     error_out = open(os.path.join(original_exp, '{0}_{1}_{2}_{3}_hyperparams.error'.format(opt.dataset, opt.D, opt.A, opt.H)), 'w')
618 |     with open(os.path.join(original_exp, '{0}_{1}_{2}_{3}_hyperparams.log'.format(opt.dataset, opt.D, opt.A, opt.H)), 'w') as out:
619 |         max_logprob = 0
620 |         best_config = ''
621 |         if opt.adam:
622 |             for lr in learn_rates:
623 |                 opt.lrD = lr
624 |                 opt.lrG = lr
625 |                 opt.experiment = os.path.join(original_exp,
626 |                                               '{0}_{1}_{2}_{3}_{4}'.format(opt.dataset, opt.D, opt.A, opt.H, lr))
627 |                 if not os.path.exists(opt.experiment):
628 |                     os.makedirs(opt.experiment)
629 |                 try:
630 |                     logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment,
631 |                                                                         '{0}_{1}_{2}_{3}_evaluation.log'.format(
632 |                                                                             opt.dataset, opt.D, opt.A, opt.H)))
633 |                     config = '{0}\t{1}\n'.format(lr, logprob)
634 |                     if max_logprob == 0 or logprob > max_logprob:
635 |                         max_logprob = logprob
636 |                         best_config = config
637 |                     out.write(config)
638 |                     out.flush()
639 |                 except Exception as e:
640 |                     print(e)
641 |                     error_out.write('learning rate {0} with error {1}'.format(lr, e))
642 |             print(best_config)
643 |         elif opt.A == 'dcgan' and opt.D != 'wgan':
644 |             for lr in learn_rates:
645 |                 opt.lrD = 
lr 646 | opt.lrG = lr 647 | for c in clamping_bounds: 648 | opt.clamp_lower = -c 649 | opt.clamp_upper = c 650 | opt.experiment = os.path.join(original_exp, '{0}_{1}_{2}_{3}_{4}_{5}'.format(opt.dataset, opt.D, opt.A, opt.H, lr, c)) 651 | if not os.path.exists(opt.experiment): 652 | os.makedirs(opt.experiment) 653 | try: 654 | logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_evaluation.log'.format(opt.dataset, opt.D, opt.A, opt.H))) 655 | config = '{0}\t{1}\t{2}\n'.format(lr, c, logprob) 656 | if max_logprob == 0 or logprob > max_logprob: 657 | max_logprob = logprob 658 | best_config = config 659 | out.write(config) 660 | out.flush() 661 | except Exception as e: 662 | print(e) 663 | error_out.write('config {0} {1} with error {2}'.format(lr, c, e)) 664 | print(best_config) 665 | elif opt.D == 'wgan': 666 | opt.clamp_lower = -0.01 667 | opt.clamp_upper = 0.01 668 | for lr in learn_rates: 669 | opt.lrD = lr 670 | opt.lrG = lr 671 | opt.experiment = os.path.join(original_exp, '{0}_{1}_{2}_{3}_{4}'.format(opt.dataset, opt.D, opt.A, opt.H, lr)) 672 | if not os.path.exists(opt.experiment): 673 | os.makedirs(opt.experiment) 674 | try: 675 | logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_evaluation.log'.format(opt.dataset, opt.D, opt.A, opt.H))) 676 | config = '{0}\t{1}\n'.format(lr, logprob) 677 | if max_logprob == 0 or logprob > max_logprob: 678 | max_logprob = logprob 679 | best_config = config 680 | out.write(config) 681 | out.flush() 682 | except Exception as e: 683 | print(e) 684 | error_out.write('learning rate {0} with error {1}'.format(lr, e)) 685 | print(best_config) 686 | else: 687 | for lr in learn_rates: 688 | opt.lrD = lr 689 | opt.lrG = lr 690 | for h in hidden_units: 691 | opt.ngf = h 692 | opt.ndf = h 693 | for c in clamping_bounds: 694 | opt.clamp_lower = -c 695 | opt.clamp_upper = c 696 | opt.experiment = os.path.join(original_exp, 697 | '{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(opt.dataset, opt.D, opt.A, 698 | opt.H, lr, h, c)) 699 | if not os.path.exists(opt.experiment): 700 | os.makedirs(opt.experiment) 701 | try: 702 | 703 | logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_evaluation.log'.format(opt.dataset, opt.D, opt.A, opt.H))) 704 | config = '{0}\t{1}\t{2}\t{3}\n'.format(lr, h, c, logprob) 705 | if max_logprob == 0 or logprob > max_logprob: 706 | max_logprob = logprob 707 | best_config = config 708 | out.write(config) 709 | out.flush() 710 | except Exception as e: 711 | print(e) 712 | error_out.write('config {0} {1} {2} with error {3}'.format(lr, h, c, e)) 713 | print(best_config) 714 | out.write('best configuration : {0}'.format(best_config)) 715 | error_out.close() 716 | 717 | def search_lsun_hyperparams(opt): 718 | learn_rates = [0.0001, 0.00005, 0.00001] 719 | hidden_units = [opt.imageSize, 1024] 720 | clamping_bounds = [0, 0.1, 0.01, 0.001] 721 | original_exp = opt.experiment 722 | error_out = open(os.path.join(original_exp, '{0}_{1}_{2}_{3}_hyperparams.error'.format(opt.dataset, opt.D, opt.A, opt.H)), 'w') 723 | with open(os.path.join(original_exp, '{0}_{1}_{2}_{3}_hyperparams.log'.format(opt.dataset, opt.D, opt.A, opt.H)), 'w') as out: 724 | max_logprob = 0 725 | best_config = '' 726 | if opt.A == 'dcgan': 727 | for lr in learn_rates: 728 | opt.lrD = lr 729 | opt.lrG = lr 730 | for c in clamping_bounds: 731 | opt.clamp_lower = -c 732 | opt.clamp_upper = c 733 | opt.experiment = os.path.join(original_exp, '{0}_{1}_{2}_{3}_{4}_{5}'.format(opt.dataset, opt.D, opt.A, 
opt.H, lr, c))
734 |                     if not os.path.exists(opt.experiment):
735 |                         os.makedirs(opt.experiment)
736 |                     try:
737 |                         logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_evaluation.log'.format(opt.dataset, opt.D, opt.A, opt.H)))
738 |                         config = '{0}\t{1}\t{2}\n'.format(lr, c, logprob)
739 |                         if max_logprob == 0 or logprob > max_logprob:
740 |                             max_logprob = logprob
741 |                             best_config = config
742 |                         out.write(config)
743 |                         out.flush()
744 |                     except Exception as e:
745 |                         print(e)
746 |                         error_out.write('config {0} {1} with error {2}'.format(lr, c, e))
747 |             print(best_config)
748 |         else:
749 |             for lr in learn_rates:
750 |                 opt.lrD = lr
751 |                 opt.lrG = lr
752 |                 for h in hidden_units:
753 |                     opt.ngf = h
754 |                     opt.ndf = h
755 |                     for c in clamping_bounds:
756 |                         opt.clamp_lower = -c
757 |                         opt.clamp_upper = c
758 |                         opt.experiment = os.path.join(original_exp,
759 |                                                       '{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(opt.dataset, opt.D, opt.A,
760 |                                                                                            opt.H, lr, h, c))
761 |                         if not os.path.exists(opt.experiment):
762 |                             os.makedirs(opt.experiment)
763 |                         try:
764 | 
765 |                             logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_evaluation.log'.format(opt.dataset, opt.D, opt.A, opt.H)))
766 |                             config = '{0}\t{1}\t{2}\t{3}\n'.format(lr, h, c, logprob)
767 |                             if max_logprob == 0 or logprob > max_logprob:
768 |                                 max_logprob = logprob
769 |                                 best_config = config
770 |                             out.write(config)
771 |                             out.flush()
772 |                         except Exception as e:
773 |                             print(e)
774 |                             error_out.write('config {0} {1} {2} with error {3}'.format(lr, h, c, e))
775 |             print(best_config)
776 |         out.write('best configuration : {0}'.format(best_config))
777 |     error_out.close()
778 | 
779 | def read(file_path):
780 |     mus = set()
781 |     if os.path.exists(file_path):
782 |         with open(file_path, 'r') as csvfile:
783 |             reader = csv.reader(csvfile, delimiter=',')
784 | 
785 |             for row in reader:
786 |                 if len(row) > 1 and len(row[0]) > 0:
787 |                     mus.add(float(row[0]))
788 |     return mus
789 | 
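# Format sketch: read() above parses the resume log that search_mu below
# appends to, one 'mu,logprob' row per finished configuration, e.g.
#
#   0.0,<mean logprob>
#   0.1,<mean logprob>
#
# so that a re-run skips mu values that already have a row.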
790 | def search_mu(opt, start=0, end=11):
791 |     mus = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
792 |     original_exp = opt.experiment
793 |     file_name = '{0}_{1}_{2}_{3}_{4}_{5}_search_mu.log'.format(opt.dataset, opt.D, opt.A, opt.H, opt.manualSeed, opt.critic_last_layer)
794 |     csv_file = os.path.join(opt.experiment, file_name)
795 |     mu_set = read(csv_file)
796 |     with open(csv_file, 'a') as out:
797 |         max_logprob = 0
798 |         best_config = ''
799 |         for i in range(start, end):
800 |             mu = mus[i]
801 |             if mu not in mu_set:
802 |                 opt.mu = mu
803 |                 opt.H = 'murelu'
804 |                 if mu == 1:
805 |                     opt.H = 'relu'
806 |                 try:
807 |                     opt.experiment = os.path.join(original_exp, '{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(opt.dataset, opt.D, opt.A, opt.H, mu, opt.manualSeed, opt.critic_last_layer))
808 |                     if not os.path.exists(opt.experiment):
809 |                         os.makedirs(opt.experiment)
810 |                     logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_{4}_{5}_{6}_matsushita_mu.log'.format(opt.dataset, opt.D, opt.A, opt.H, mu, opt.manualSeed, opt.critic_last_layer)))
811 |                     config = '{0},{1}\n'.format(mu, logprob)
812 |                     if max_logprob == 0 or logprob > max_logprob:
813 |                         max_logprob = logprob
814 |                         best_config = config
815 |                     out.write(config)
816 |                     out.flush()
817 |                     print('best %s ' % best_config)
818 |                 except Exception:
819 |                     traceback.print_exc()
820 |     #out.write('best configuration : {0}'.format(best_config))
821 | 
822 | def experiments_randseeds(opt, start=0, end=5):
823 |     random_seeds = [1, 101, 512, 1001, 10001]
824 |     original_exp = opt.experiment
825 |     file_name = '{0}_{1}_{2}_{3}_{4}_experiments.csv'.format(opt.dataset, opt.D, opt.A, opt.H, opt.critic_last_layer)
826 |     csv_file = os.path.join(opt.experiment, file_name)
827 |     with open(csv_file, 'a') as out:
828 |         max_logprob = 0
829 |         best_config = ''
830 |         for i in range(start, end):
831 |             rand_seed = random_seeds[i]
832 |             opt.manualSeed = rand_seed
833 |             try:
834 |                 opt.experiment = os.path.join(original_exp, '{0}_{1}_{2}_{3}_{4}_{5}'.format(opt.dataset, opt.D, opt.A, opt.H, opt.manualSeed, opt.critic_last_layer))
835 |                 if not os.path.exists(opt.experiment):
836 |                     os.makedirs(opt.experiment)
837 |                 logprob = train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_{4}_{5}_experiments.log'.format(opt.dataset, opt.D, opt.A, opt.H, opt.manualSeed, opt.critic_last_layer)))
838 |                 config = '{0},{1}\n'.format(rand_seed, logprob)
839 |                 if max_logprob == 0 or logprob > max_logprob:
840 |                     max_logprob = logprob
841 |                     best_config = config
842 |                 out.write(config)
843 |                 out.flush()
844 |                 print('best %s ' % best_config)
845 |             except Exception:
846 |                 traceback.print_exc()
847 | 
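# A minimal sketch (hypothetical helper, not called anywhere): aggregate the
# '<seed>,<logprob>' rows that experiments_randseeds above appends to its
# *_experiments.csv into a mean/std summary across random seeds.
def summarize_seed_runs(csv_path):
    logprobs = []
    with open(csv_path) as csvfile:
        for row in csv.reader(csvfile):
            if len(row) > 1:
                logprobs.append(float(row[1]))
    return np.mean(logprobs), np.std(logprobs)
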
848 | def main():
849 |     parser = argparse.ArgumentParser()
850 |     parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | mnist')
851 |     parser.add_argument('--subset', help='tower | bedroom')
852 |     parser.add_argument('--dataroot', required=True, help='path to dataset')
853 |     parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
854 |     parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
855 |     parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
856 |     parser.add_argument('--gpu_id', type=int, default=-1, help='GPU id')
857 |     parser.add_argument('--nc', type=int, default=3, help='input image channels')
858 |     parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
859 |     parser.add_argument('--ngf', type=int, default=64)
860 |     parser.add_argument('--ndf', type=int, default=64)
861 |     parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
862 |     parser.add_argument('--lrD', type=float, default=0.0001, help='learning rate for Critic, default=0.0001')  # 0.00005
863 |     parser.add_argument('--lrG', type=float, default=0.0001, help='learning rate for Generator, default=0.0001')  # 0.00005
864 |     parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
865 |     parser.add_argument('--alpha', type=float, default=1, help='alpha for elu. default=1')
866 |     parser.add_argument('--cuda', action='store_true', help='enables cuda')
867 |     parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
868 |     parser.add_argument('--netG', default='', help="path to netG (to continue training)")
869 |     parser.add_argument('--netD', default='', help="path to netD (to continue training)")
870 |     parser.add_argument('--clamp_lower', type=float, default=-0.01)
871 |     parser.add_argument('--clamp_upper', type=float, default=0.01)
872 |     parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
873 |     parser.add_argument('--noBN', action='store_true', help='use batchnorm or not (only for DCGAN)')
874 |     parser.add_argument('--D', default='kl', help='divergence : kl | fgan | gan | wgan')
875 |     parser.add_argument('--A', default='wmlp', help='architecture : dcgan | wmlp | mlp')
876 |     parser.add_argument('--H', default='relu', help='activation function in hidden layers : relu | murelu | elu | ls | sp')
877 |     parser.add_argument('--c_activation', default='none', help='activation function in the hidden layers of critic: relu | elu')
878 |     parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
879 |     parser.add_argument('--experiment', default=None, help='Where to store samples and models')
880 |     parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
881 |     parser.add_argument('--bandwidth', type=float, default=0, help='optimal bandwidth for KDE, default=0')
882 |     parser.add_argument('--kdeEpoch', type=int, default=0, help='The epoch of the model that is loaded for KDE evaluation.')
883 |     parser.add_argument('--weightDecay', type=float, default=0)
884 |     parser.add_argument('--task', default='train', help='train | eval_kde | mu | murange | rand | hyper')
885 |     parser.add_argument('--init_z', default='gaussian', help='uniform_one | uniform_zero_one | gaussian')
886 |     parser.add_argument('--init_w', default='xavier', help='xavier | uniform')
887 |     parser.add_argument('--normalizeImages', action='store_true', help='rescale images to [-1, 1] (the generator then ends with tanh)')
888 |     parser.add_argument('--last_layer', default='sigmoid', help='none | sigmoid | tanh')
889 |     parser.add_argument('--critic_last_layer', default='none', help='none | sigmoid | tanh | matsu')
890 |     parser.add_argument('--mu', type=float, default=0, help='mu for matsushita')
891 |     parser.add_argument('--manualSeed', type=int, default=512, help='random seed')
892 |     parser.add_argument('--wganheuristics', action='store_true', help='use the WGAN schedule of extra D iterations early in training')
893 |     parser.add_argument('--kde_result_dir', default='', help='folder which stores the images and KDE results created by the generator')
894 |     parser.add_argument('--start', type=int, default=0,
895 |                         help='The starting epoch of the model that is loaded for KDE evaluation.')
896 |     parser.add_argument('--end', type=int, default=99,
897 |                         help='The last epoch of the model that is loaded for KDE evaluation.')
898 |     opt = parser.parse_args()
899 | 
900 | 
901 |     # dataset-specific defaults; CUDA is enabled via the --cuda flag
902 | 
903 | 
904 |     if opt.dataset == 'mnist':
905 |         opt.imageSize = 32
906 |         opt.last_layer = 'sigmoid'
907 |         opt.niter = 100
908 |     elif opt.dataset == 'lsun':
909 |         opt.imageSize = 64
910 |         opt.last_layer = 'tanh'
911 | 
912 |     if opt.A == 'mlp':
913 |         if opt.D == 'gan':
914 |             opt.lrD = 0.0002
915 |             opt.lrG = 0.0002
916 |             opt.c_activation = 'elu'
917 |             opt.init_w = 'xavier'
918 |             opt.init_z = 'uniform_zero_one'
919 |             opt.batchSize = 64
920 |             opt.clamp_lower = 0
921 |             opt.clamp_upper = 0
922 |             opt.adam = True
923 |             opt.ndf = 1024
924 |             opt.ngf = 1024
925 | 
if opt.critic_last_layer == 'none': 926 | opt.critic_last_layer = 'sigmoid' 927 | 928 | elif opt.D == 'wgan': 929 | opt.lrD = 0.0002 930 | opt.lrG = 0.0002 931 | opt.c_activation = 'elu' 932 | opt.init_w = 'xavier' 933 | opt.init_z = 'uniform_zero_one' 934 | opt.batchSize = 64 935 | opt.clamp_lower = -0.01 936 | opt.clamp_upper = 0.01 937 | opt.adam = False 938 | opt.ndf = 1024 939 | opt.ngf = 1024 940 | elif opt.A == 'dcgan': 941 | if opt.D == 'gan': 942 | opt.lrD = 0.0002 943 | opt.lrG = 0.0002 944 | opt.init_z = 'gaussian' 945 | opt.batchSize = 64 946 | opt.clamp_lower = 0 947 | opt.clamp_upper = 0 948 | opt.adam = True 949 | if opt.critic_last_layer == 'none': 950 | opt.critic_last_layer = 'sigmoid' 951 | elif opt.D == 'wgan': 952 | opt.lrD = 0.0002 953 | opt.lrG = 0.0002 954 | opt.init_z = 'gaussian' 955 | opt.batchSize = 64 956 | opt.clamp_lower = -0.01 957 | opt.clamp_upper = 0.01 958 | opt.adam = False 959 | 960 | if opt.cuda: 961 | if opt.gpu_id >= 0: 962 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_id) 963 | 964 | if opt.task == 'train': 965 | train(opt=opt, log_file_path=os.path.join(opt.experiment, '{0}_{1}_{2}_{3}_{4}_{5}.log'.format(opt.D, opt.A, opt.H, opt.c_activation, opt.manualSeed, opt.critic_last_layer))) 966 | elif opt.task == 'eval_kde': 967 | print('max log prob is %s ' % train(opt=opt, log_file_path=os.path.join(opt.kde_result_dir, '{0}_{1}_{2}.log'.format(opt.D, opt.A, opt.H)))) 968 | elif opt.task == 'mu': 969 | search_mu(opt=opt) 970 | elif opt.task == 'murange': 971 | search_mu(opt=opt, start = opt.start, end = opt.end) 972 | elif opt.task == 'rand': 973 | experiments_randseeds(opt=opt, start=opt.start, end=opt.end) 974 | else: 975 | assert opt.task == 'hyper' 976 | if opt.dataset == 'lsun': 977 | search_lsun_hyperparams(opt=opt) 978 | else: 979 | search_hyperparams(opt=opt) 980 | 981 | 982 | if __name__ == '__main__': 983 | main() --------------------------------------------------------------------------------