├── .gitignore ├── README.md ├── config.py ├── data_loader.py ├── imgs ├── objective.png ├── objective_graph.png └── pytorch.jpg ├── main.py ├── networks ├── Discriminator.py ├── Generator.py └── __init__.py └── scripts ├── cell.sh ├── cifar10.sh └── mnist.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSGAN.pytorch 2 | Repository for Pytorch Implementation of Least Squares Generative Adversarial Networks 3 | 4 | ![alt_tag](./imgs/pytorch.jpg) 5 | 6 | ## Least Squares Generative Adversarial Networks 7 | 8 | Regular GANs hypothesize the discriminator as a classifier with the sigmoid cross entropy loss function. 9 | 10 | This loss function, however, may lead to the [vanishing gradient problem](https://en.wikipedia.org/wiki/Vanishing_gradient_problem) during the learning process. 11 | 12 | LSGANs(Least Squares GAN) adopt the least squares loss function for the discriminator. 13 | 14 | The paper shows that the objective function for LSGAN yields minimizing the [Pearson chi-square divergence](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.518.7353&rep=rep1&type=pdf). 15 | 16 | ![alt_tag](./imgs/objective_graph.png) 17 | ![alt_tag](./imgs/objective.png) 18 | 19 | ## Basic Setups 20 | 21 | ## How to run 22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | ############## Deep Convolutional Generative Adversarial Networks ################ 2 | 3 | mean = { 4 | 'cifar10': (0.4914, 0.4822, 0.4465), 5 | 'mnist' : (0.5, 0.5, 0.5), 6 | 'cell' : (0.5, 0.5, 0.5), 7 | } 8 | 9 | std = { 10 | 'cifar10': (0.2023, 0.1994, 0.2010), 11 | 'mnist' : (0.5, 0.5, 0.5), 12 | 'cell' : (0.5, 0.5, 0.5), 13 | } 14 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | from torchvision import transforms 4 | from PIL import Image 5 | 6 | class ImageFolder(data.Dataset): 7 | def __init__(self, root, transform=None): 8 | self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root))) 9 | self.transform = transform 10 | 11 | def __getitem__(self, index): 12 | image_path = self.image_paths[index] 13 | image = Image.open(image_path).convert('RGB') 14 | if self.transform is not None: 15 | image = self.transform(image) 16 | return image 17 | 18 | def __len__(self): 19 | return len(self.image_paths) 20 | 21 | def get_loader(image_path, image_size, batch_size, transform, num_workers=2): 22 | dataset = ImageFolder(image_path, transform) 23 | data_laoder = data.DataLoader( 24 | dataset = dataset, 25 | batch_size = batch_size, 26 | shuffle = True, 27 | num_workers = num_workers 28 | ) 29 | 30 | return data_loader 31 | -------------------------------------------------------------------------------- /imgs/objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmsookim/LSGAN.pytorch/32401dffb5fe61c8c7b93d1b01817ba40c58a239/imgs/objective.png -------------------------------------------------------------------------------- /imgs/objective_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmsookim/LSGAN.pytorch/32401dffb5fe61c8c7b93d1b01817ba40c58a239/imgs/objective_graph.png -------------------------------------------------------------------------------- /imgs/pytorch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmsookim/LSGAN.pytorch/32401dffb5fe61c8c7b93d1b01817ba40c58a239/imgs/pytorch.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | 9 | import torchvision 10 | import torchvision.datasets as dset 11 | import torchvision.transforms as transforms 12 | import torchvision.utils as vutils 13 | import config as cf 14 | 15 | import os 16 | import sys 17 | import time 18 | import argparse 19 | import datetime 20 | import random 21 | 22 | from torch.autograd import Variable 23 | from data_loader import ImageFolder 24 | from networks import * 25 | 26 | # Parser 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--dataset', required=True, help='mnist | cifar10') 29 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 30 | parser.add_argument('--imageSize', type=int, default=64, help='the width & height of the input image') 31 | 32 | parser.add_argument('--nz', type=int, default=100, help='size of latent z vector') 33 | parser.add_argument('--ngf', type=int, default=64, help='number of generator filters') 34 | parser.add_argument('--ndf', type=int, default=64, help='number of discriminator filters') 35 | 36 | parser.add_argument('--nEpochs', type=int, default=25, help='number of epochs to train for') 37 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=2e-4') 38 | parser.add_argument('--beta', type=float, default=0.5, help='beta1 for adam optimizer, default=0.5') 39 | 40 | parser.add_argument('--nGPU', type=int, default=2, help='number of GPUs to use') 41 | parser.add_argument('--outf', default='./checkpoints/', help='folder to output images and model checkpoints') 42 | parser.add_argument('--manualSeed', type=int, help='manual seed') 43 | 44 | opt = parser.parse_args() 45 | 46 | try: 47 | os.makedirs(opt.outf) 48 | except OSError: 49 | pass 50 | 51 | if opt.manualSeed is None: 52 | opt.manualSeed = random.randint(1, 10000) 53 | random.seed(opt.manualSeed) 54 | torch.manual_seed(opt.manualSeed) 55 | torch.cuda.manual_seed_all(opt.manualSeed) 56 | 57 | use_cuda = torch.cuda.is_available() 58 | cudnn.benchmark = True 59 | use_cuda = torch.cuda.is_available() 60 | 61 | ######################### Data Preperation 62 | print("\n[Phase 1] : Data Preperation") 63 | print("| Preparing %s dataset..." %(opt.dataset)) 64 | 65 | dset_transforms = transforms.Compose([ 66 | transforms.Scale(opt.imageSize), 67 | transforms.ToTensor(), 68 | transforms.Normalize(cf.mean[opt.dataset], cf.std[opt.dataset]) 69 | ]) 70 | 71 | if (opt.dataset == 'cifar10'): 72 | dataset = dset.CIFAR10( 73 | root='/home/bumsoo/Data/GAN/cifar10/', 74 | download=True, 75 | transform=dset_transforms 76 | ) 77 | elif (opt.dataset == 'mnist'): 78 | dataset = dset.MNIST( 79 | root='/home/bumsoo/Data/GAN/mnist/', 80 | download=True, 81 | transform=dset_transforms 82 | ) 83 | elif (opt.dataset == 'cell') : 84 | dataset = ImageFolder( 85 | root='/home/bumsoo/Data/GAN/cell/', 86 | transform=dset_transforms 87 | ) 88 | else: 89 | print("Error | Dataset must be one of mnist | cifar10") 90 | sys.exit(1) 91 | 92 | print("| Consisting data loader for %s..." %(opt.dataset)) 93 | loader = torch.utils.data.DataLoader( 94 | dataset = dataset, 95 | batch_size = opt.batchSize, 96 | shuffle = True 97 | ) 98 | 99 | ######################### Model Setup 100 | print("\n[Phase 2] : Model Setup") 101 | ndf = opt.ndf 102 | ngf = opt.ngf 103 | 104 | if(opt.dataset == 'cifar10'): 105 | nc = 3 106 | elif(opt.dataset == 'cell'): 107 | nc = 3 108 | elif(opt.dataset == 'mnist'): 109 | nc = 1 110 | else: 111 | print("Error : Dataset must be one of \'mnist | cifar10 | cell\'") 112 | sys.exit(1) 113 | 114 | print("| Consisting Discriminator with ndf=%d" %ndf) 115 | print("| Consisting Generator with z=%d" %opt.nz) 116 | netD = Discriminator(ndf, nc) 117 | netG = Generator(opt.nz, ngf, nc) 118 | 119 | if(use_cuda): 120 | netD.cuda() 121 | netG.cuda() 122 | cudnn.benchmark = True 123 | 124 | ######################### Loss & Optimizer 125 | criterion = nn.BCELoss() 126 | optimizerD = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta, 0.999)) 127 | optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta, 0.999)) 128 | 129 | ######################### Global Variables 130 | noise = torch.FloatTensor(opt.batchSize, opt.nz, 1, 1) 131 | real = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize) 132 | label = torch.FloatTensor(opt.batchSize) 133 | real_label, fake_label = 1, 0 134 | 135 | noise = Variable(noise) 136 | real = Variable(real) 137 | label = Variable(label) 138 | 139 | if(use_cuda): 140 | noise = noise.cuda() 141 | real = real.cuda() 142 | label = label.cuda() 143 | 144 | ######################### Training Stage 145 | print("\n[Phase 4] : Train model") 146 | for epoch in range(1, opt.nEpochs+1): 147 | for i, (images) in enumerate(loader): # We don't need the class label information 148 | 149 | ######################### fDx : Gradient of Discriminator 150 | netD.zero_grad() 151 | 152 | # train with real data 153 | real.data.resize_(images.size()).copy_(images) 154 | label.data.resize_(images.size(0)).fill_(real_label) 155 | 156 | output = netD(real) # Forward propagation, this should result in '1' 157 | errD_real = 0.5 * torch.mean((output-label)**2) # criterion(output, label) 158 | errD_real.backward() 159 | 160 | # train with fake data 161 | label.data.fill_(fake_label) 162 | noise.data.resize_(images.size(0), opt.nz, 1, 1) 163 | noise.data.normal_(0, 1) 164 | 165 | fake = netG(noise) # Create fake image 166 | output = netD(fake.detach()) # Forward propagation for fake, this should result in '0' 167 | errD_fake = 0.5 * torch.mean((output-label)**2) # criterion(output, label) 168 | errD_fake.backward() 169 | #### Appendix #### 170 | #### var.detach() = Variable(var.data), difference in computing trigger 171 | 172 | errD = errD_fake + errD_real 173 | optimizerD.step() 174 | 175 | ######################### fGx : Gradient of Generator 176 | netG.zero_grad() 177 | label.data.fill_(real_label) 178 | output = netD(fake) # Forward propagation of generated image, this should result in '1' 179 | errG = 0.5 * torch.mean((output - label)**2) # criterion(output, label) 180 | errG.backward() 181 | optimizerG.step() 182 | 183 | ######################### LOG 184 | sys.stdout.write('\r') 185 | sys.stdout.write('| Epoch [%2d/%2d] Iter [%3d/%3d] Loss(D): %.4f Loss(G): %.4f ' 186 | %(epoch, opt.nEpochs, i, len(loader), errD.data[0], errG.data[0])) 187 | sys.stdout.flush() 188 | 189 | ######################### Visualize 190 | if(i%1 == 0): 191 | print(": Saving current results...") 192 | vutils.save_image( 193 | fake.data, 194 | '%s/fake_samples_%03d.png' %(opt.outf, epoch), 195 | normalize=True 196 | ) 197 | 198 | ######################### Save model 199 | torch.save(netG.state_dict(), '%s/netG.pth' %(opt.outf)) 200 | torch.save(netD.state_dict(), '%s/netD.pth' %(opt.outf)) 201 | -------------------------------------------------------------------------------- /networks/Discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | class Discriminator(nn.Module): 8 | def __init__(self, ndf, nChannels): 9 | super(Discriminator, self).__init__() 10 | # input : (batch * nChannels * image width * image height) 11 | # Discriminator will be consisted with a series of convolution networks 12 | 13 | self.layer1 = nn.Sequential( 14 | # Input size : input image with dimension (nChannels)*64*64 15 | # Output size: output feature vector with (ndf)*32*32 16 | nn.Conv2d( 17 | in_channels = nChannels, 18 | out_channels = ndf, 19 | kernel_size = 4, 20 | stride = 2, 21 | padding = 1, 22 | bias = False 23 | ), 24 | nn.BatchNorm2d(ndf), 25 | nn.LeakyReLU(0.2, inplace=True) 26 | ) 27 | 28 | self.layer2 = nn.Sequential( 29 | # Input size : input feature vector with (ndf)*32*32 30 | # Output size: output feature vector with (ndf*2)*16*16 31 | nn.Conv2d( 32 | in_channels = ndf, 33 | out_channels = ndf*2, 34 | kernel_size = 4, 35 | stride = 2, 36 | padding = 1, 37 | bias = False 38 | ), 39 | nn.BatchNorm2d(ndf*2), 40 | nn.LeakyReLU(0.2, inplace=True) 41 | ) 42 | 43 | self.layer3 = nn.Sequential( 44 | # Input size : input feature vector with (ndf*2)*16*16 45 | # Output size: output feature vector with (ndf*4)*8*8 46 | nn.Conv2d( 47 | in_channels = ndf*2, 48 | out_channels = ndf*4, 49 | kernel_size = 4, 50 | stride = 2, 51 | padding = 1, 52 | bias = False 53 | ), 54 | nn.BatchNorm2d(ndf*4), 55 | nn.LeakyReLU(0.2, inplace=True) 56 | ) 57 | 58 | self.layer4 = nn.Sequential( 59 | # Input size : input feature vector with (ndf*4)*8*8 60 | # Output size: output feature vector with (ndf*8)*4*4 61 | nn.Conv2d( 62 | in_channels = ndf*4, 63 | out_channels = ndf*8, 64 | kernel_size = 4, 65 | stride = 2, 66 | padding = 1, 67 | bias = False 68 | ), 69 | nn.BatchNorm2d(ndf*8), 70 | nn.LeakyReLU(0.2, inplace=True) 71 | ) 72 | 73 | self.layer5 = nn.Sequential( 74 | # Input size : input feature vector with (ndf*8)*4*4 75 | # Output size: output probability of fake/real image 76 | nn.Conv2d( 77 | in_channels = ndf*8, 78 | out_channels = 1, 79 | kernel_size = 4, 80 | stride = 1, 81 | padding = 0, 82 | bias = False 83 | ), 84 | # nn.Sigmoid() -- Replaced with Least Square Loss 85 | ) 86 | 87 | def forward(self, x): 88 | out = self.layer1(x) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = self.layer4(out) 92 | out = self.layer5(out) 93 | 94 | return out.view(-1,1) 95 | 96 | if __name__ == "__main__": 97 | net = Discriminator( 98 | nChannels = 3, 99 | ndf = 64 100 | ) 101 | print "Input(=image) : ", 102 | print(torch.randn(128,3,64,64).size()) 103 | y = net(Variable(torch.randn(128,3,64,64))) # Input should be a 4D tensor 104 | print "Output(batchsize, channels, width, height) : ", 105 | print(y.size()) 106 | -------------------------------------------------------------------------------- /networks/Generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, nz, ngf, nChannels): 9 | super(Generator, self).__init__() 10 | # input : z 11 | # Generator will be consisted with a series of deconvolution networks 12 | 13 | self.layer1 = nn.Sequential( 14 | # input : z 15 | # Generator will be consisted with a series of deconvolution networks 16 | 17 | # Input size : input latent vector 'z' with dimension (nz)*1*1 18 | # Output size: output feature vector with (ngf*8)*4*4 19 | nn.ConvTranspose2d( 20 | in_channels = nz, 21 | out_channels = ngf*8, 22 | kernel_size = 4, 23 | stride = 1, 24 | padding = 0, 25 | bias = False 26 | ), 27 | nn.BatchNorm2d(ngf*8), 28 | nn.ReLU(True) 29 | ) 30 | 31 | self.layer2 = nn.Sequential( 32 | # Input size : input feature vector with (ngf*8)*4*4 33 | # Output size: output feature vector with (ngf*4)*8*8 34 | nn.ConvTranspose2d( 35 | in_channels = ngf*8, 36 | out_channels = ngf*4, 37 | kernel_size = 4, 38 | stride = 2, 39 | padding = 1, 40 | bias = False 41 | ), 42 | nn.BatchNorm2d(ngf*4), 43 | nn.ReLU(True) 44 | ) 45 | 46 | self.layer3 = nn.Sequential( 47 | # Input size : input feature vector with (ngf*4)*8*8 48 | # Output size: output feature vector with (ngf*2)*16*16 49 | nn.ConvTranspose2d( 50 | in_channels = ngf*4, 51 | out_channels = ngf*2, 52 | kernel_size = 4, 53 | stride = 2, 54 | padding = 1, 55 | bias = False 56 | ), 57 | nn.BatchNorm2d(ngf*2), 58 | nn.ReLU(True) 59 | ) 60 | 61 | self.layer4 = nn.Sequential( 62 | # Input size : input feature vector with (ngf*2)*16*16 63 | # Output size: output feature vector with (ngf)*32*32 64 | nn.ConvTranspose2d( 65 | in_channels = ngf*2, 66 | out_channels = ngf, 67 | kernel_size = 4, 68 | stride = 2, 69 | padding = 1, 70 | bias = False 71 | ), 72 | nn.BatchNorm2d(ngf), 73 | nn.ReLU(True) 74 | ) 75 | 76 | self.layer5 = nn.Sequential( 77 | # Input size : input feature vector with (ngf)*32*32 78 | # Output size: output image with (nChannels)*(image width)*(image height) 79 | nn.ConvTranspose2d( 80 | in_channels = ngf, 81 | out_channels = nChannels, 82 | kernel_size =4, 83 | stride = 2, 84 | padding = 1, 85 | bias = False 86 | ), 87 | nn.Tanh() # To restrict each pixels of the fake image to 0~1 88 | # Yunjey seems to say that this does not matter much 89 | ) 90 | 91 | def forward(self, x): 92 | out = self.layer1(x) 93 | out = self.layer2(out) 94 | out = self.layer3(out) 95 | out = self.layer4(out) 96 | out = self.layer5(out) 97 | 98 | return out 99 | 100 | if __name__ == '__main__': 101 | net = Generator( 102 | nz = 100, 103 | ngf = 64, 104 | nChannels = 1 105 | ) 106 | print "Input(=z) : ", 107 | print(torch.randn(128,100,1,1).size()) 108 | y = net(Variable(torch.randn(128,100,1,1))) # Input should be a 4D tensor 109 | print "Output(batchsize, channels, width, height) : ", 110 | print(y.size()) 111 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .Discriminator import * 2 | from .Generator import * 3 | -------------------------------------------------------------------------------- /scripts/cell.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --dataset cell \ 3 | --lr 1e-3 \ 4 | --nEpochs 300 5 | -------------------------------------------------------------------------------- /scripts/cifar10.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --dataset cifar10 3 | -------------------------------------------------------------------------------- /scripts/mnist.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --dataset mnist 3 | --------------------------------------------------------------------------------