├── .gitignore
├── LICENSE
├── README.md
├── figs
│   ├── architecture.png
│   ├── fake_samples_epoch_470.png
│   └── fake_samples_epoch_499.png
├── folder.py
├── main.py
├── network.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
outputs*
*.pyc
*.sw[po]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Te-Lin Wu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Conditional Image Synthesis With Auxiliary Classifier GANs

As part of the implementation series of [Joseph Lim's group at USC](http://csail.mit.edu/~lim), our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and publicly share them with concise reports. Please visit our [group github site](https://github.com/gitlimlab) for other projects.

This project is implemented by [Te-Lin Wu](https://github.com/telin0411), and the code has been reviewed by [Shao-Hua Sun](https://github.com/shaohua0116) before being published.

## Descriptions
This project is a [PyTorch](http://pytorch.org) implementation of [Conditional Image Synthesis With Auxiliary Classifier GANs](https://arxiv.org/abs/1610.09585), which was published as a conference proceeding at ICML 2017. The paper proposes a simple extension of GANs that employs label conditioning to produce high-resolution, high-quality generated images.

By adding an auxiliary classifier to the discriminator of a GAN, the discriminator produces not only a probability distribution over sources (real vs. fake) but also a probability distribution over the class labels. This simple modification to the standard DCGAN model is not a dramatic departure, yet it produces better results and helps stabilize the adversarial training.
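
Concretely, the discriminator loss is the sum of a source (real/fake) term and a class term. The sketch below is illustrative only: it uses dummy tensors in place of real discriminator outputs (and current PyTorch tensor APIs rather than the repository's older `Variable` style), but it mirrors the real-batch part of the discriminator update in main.py:

```python
import torch
import torch.nn as nn

batch_size, num_classes = 4, 10

# stand-ins for the two discriminator heads; in this repo, netD(input)
# returns (realfake, classes) -- see network.py
dis_output = torch.rand(batch_size)  # P(real) per sample, from the sigmoid head
aux_output = nn.functional.log_softmax(torch.randn(batch_size, num_classes), dim=1)
labels = torch.randint(0, num_classes, (batch_size,))

dis_criterion = nn.BCELoss()  # source loss: real vs. fake
aux_criterion = nn.NLLLoss()  # class loss, on log-probabilities

real_target = torch.ones(batch_size)
errD_real = dis_criterion(dis_output, real_target) + aux_criterion(aux_output, labels)
```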

The architecture is shown below, in comparison with several other GANs.

[figure: AC-GAN architecture compared with other GANs (figs/architecture.png)]

The sample images generated from the ImageNet dataset:

[figure: generated ImageNet samples]

The sample images generated from the CIFAR-10 dataset:

[figure: generated CIFAR-10 samples]

The implemented model can be trained on both the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and [ImageNet](http://www.image-net.org) datasets.

Note that this implementation may differ from the original paper in details such as model architectures, hyperparameters, and the applied optimizer, while maintaining the main proposed idea.

\*This code is still being developed and subject to change.

## Prerequisites

- Python 2.7
- [PyTorch](http://pytorch.org)
- [SciPy](http://www.scipy.org/install.html)
- [NumPy](http://www.numpy.org/)
- [PIL](http://pillow.readthedocs.io/en/3.1.x/installation.html)
- [imageio](https://imageio.github.io/)

## Usage
Run the following command for details of each argument.
```bash
$ python main.py -h
```
Specify the path to the dataset you are using with the --dataroot argument. For CIFAR-10, the code automatically checks whether the dataset has already been downloaded and downloads it for you if not. For ImageNet training, you should download the whole dataset from the official website (this repository used the 2012 version for training) and point --dataroot to the train (or val) directory as the root directory.

On line 80 of main.py, you can change the classes_idx argument to select other user-specified ImageNet classes, and adjust num_classes accordingly if it is not 10.
```python
if opt.dataset == 'imagenet':
    # folder dataset
    dataset = ImageFolder(root=opt.dataroot,
                          transform=transforms.Compose([
                              transforms.Scale(opt.imageSize),
                              transforms.CenterCrop(opt.imageSize),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ]),
                          classes_idx=(10,20))
```
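
As a quick sanity check of which classes a given classes_idx slice selects, the `find_classes` helper in folder.py can be called directly. The dataset path below is a placeholder for your own ImageNet root:

```python
from folder import find_classes

# class folders are sorted alphabetically before slicing, so (10, 20)
# selects the 11th through 20th class folders under the dataset root
classes, class_to_idx = find_classes('/data/path/to/imagenet/train', classes_idx=(10, 20))
print(len(classes))   # 10
print(class_to_idx)   # maps each selected class folder name to an index in [0, 10)
```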

### Train the models
An example training command is shown below. During training, the code automatically saves generated test images to the --outf directory.
```bash
$ python main.py --outf=/your/output/file/name --niter=500 --batchSize=100 --cuda --dataset=cifar10 --imageSize=32 --dataroot=/data/path/to/cifar10 --gpu_id=0
```
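
For reference, main.py conditions the generator by overwriting the first num_classes entries of each latent vector with a one-hot encoding of the sampled class label (presumably why the default --nz is 110 for 10 classes). A minimal NumPy sketch of that encoding, mirroring the training loop:

```python
import numpy as np

batch_size, nz, num_classes = 4, 110, 10

# sample class labels and a latent batch, as in the training loop of main.py
label = np.random.randint(0, num_classes, batch_size)
noise = np.random.normal(0, 1, (batch_size, nz))

# overwrite the first num_classes entries of each latent vector
# with the one-hot encoding of its class label
class_onehot = np.zeros((batch_size, num_classes))
class_onehot[np.arange(batch_size), label] = 1
noise[np.arange(batch_size), :num_classes] = class_onehot[np.arange(batch_size)]

# main.py then reshapes this to (batch_size, nz, 1, 1) before calling netG
noise = noise.reshape(batch_size, nz, 1, 1)
```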

## Author

Te-Lin Wu / [@telin0411](https://github.com/telin0411) @ [Joseph Lim's research lab](https://github.com/gitlimlab) @ USC

--------------------------------------------------------------------------------
/figs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/ACGAN-PyTorch/4bd0405a4bd90b07548d4b84b57e24767d7cbf65/figs/architecture.png

--------------------------------------------------------------------------------
/figs/fake_samples_epoch_470.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/ACGAN-PyTorch/4bd0405a4bd90b07548d4b84b57e24767d7cbf65/figs/fake_samples_epoch_470.png

--------------------------------------------------------------------------------
/figs/fake_samples_epoch_499.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/ACGAN-PyTorch/4bd0405a4bd90b07548d4b84b57e24767d7cbf65/figs/fake_samples_epoch_499.png

--------------------------------------------------------------------------------
/folder.py:
--------------------------------------------------------------------------------
import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Checks if a file is an image.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)


def find_classes(dir, classes_idx=None):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    if classes_idx is not None:
        assert isinstance(classes_idx, tuple)
        start, end = classes_idx
        classes = classes[start:end]
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        if target not in class_to_idx:
            continue
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        classes_idx (tuple, optional): A (start, end) tuple; only the sorted class
            folders in this index range are used.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, classes_idx=None):
        self.classes_idx = classes_idx
        classes, class_to_idx = find_classes(root, self.classes_idx)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
Code modified from PyTorch DCGAN examples: https://github.com/pytorch/examples/tree/master/dcgan
"""
from __future__ import print_function
import argparse
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from utils import weights_init, compute_acc
from network import _netG, _netD, _netD_CIFAR10, _netG_CIFAR10
from folder import ImageFolder


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | imagenet')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--imageSize', type=int, default=128, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=110, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--num_classes', type=int, default=10, help='Number of classes for AC-GAN')
parser.add_argument('--gpu_id', type=int, default=0, help='The ID of the specified GPU')

opt = parser.parse_args()
print(opt)

# specify the gpu id if using only 1 gpu
if opt.ngpu == 1:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_id)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# dataset
if opt.dataset == 'imagenet':
    # folder dataset
    dataset = ImageFolder(
        root=opt.dataroot,
        transform=transforms.Compose([
            transforms.Scale(opt.imageSize),
            transforms.CenterCrop(opt.imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        classes_idx=(10, 20)
    )
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(
        root=opt.dataroot, download=True,
        transform=transforms.Compose([
            transforms.Scale(opt.imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
else:
    raise NotImplementedError("No such dataset {}".format(opt.dataset))

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

# some hyper parameters
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
num_classes = int(opt.num_classes)
nc = 3

# Define the generator and initialize the weights
if opt.dataset == 'imagenet':
    netG = _netG(ngpu, nz)
else:
    netG = _netG_CIFAR10(ngpu, nz)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

# Define the discriminator and initialize the weights
if opt.dataset == 'imagenet':
    netD = _netD(ngpu, num_classes)
else:
    netD = _netD_CIFAR10(ngpu, num_classes)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

# loss functions
dis_criterion = nn.BCELoss()
aux_criterion = nn.NLLLoss()

# tensor placeholders
input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
eval_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
dis_label = torch.FloatTensor(opt.batchSize)
aux_label = torch.LongTensor(opt.batchSize)
real_label = 1
fake_label = 0

# if using cuda
if opt.cuda:
    netD.cuda()
    netG.cuda()
    dis_criterion.cuda()
    aux_criterion.cuda()
    input, dis_label, aux_label = input.cuda(), dis_label.cuda(), aux_label.cuda()
    noise, eval_noise = noise.cuda(), eval_noise.cuda()

# define variables
input = Variable(input)
noise = Variable(noise)
eval_noise = Variable(eval_noise)
dis_label = Variable(dis_label)
aux_label = Variable(aux_label)
# noise for evaluation: a one-hot encoding of each sampled label is embedded
# in the first num_classes dimensions of the latent vector
eval_noise_ = np.random.normal(0, 1, (opt.batchSize, nz))
eval_label = np.random.randint(0, num_classes, opt.batchSize)
eval_onehot = np.zeros((opt.batchSize, num_classes))
eval_onehot[np.arange(opt.batchSize), eval_label] = 1
eval_noise_[np.arange(opt.batchSize), :num_classes] = eval_onehot[np.arange(opt.batchSize)]
eval_noise_ = (torch.from_numpy(eval_noise_))
eval_noise.data.copy_(eval_noise_.view(opt.batchSize, nz, 1, 1))

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

avg_loss_D = 0.0
avg_loss_G = 0.0
avg_loss_A = 0.0
for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu, label = data
        batch_size = real_cpu.size(0)
        if opt.cuda:
            real_cpu = real_cpu.cuda()
        input.data.resize_as_(real_cpu).copy_(real_cpu)
        dis_label.data.resize_(batch_size).fill_(real_label)
        aux_label.data.resize_(batch_size).copy_(label)
        dis_output, aux_output = netD(input)

        dis_errD_real = dis_criterion(dis_output, dis_label)
        aux_errD_real = aux_criterion(aux_output, aux_label)
        errD_real = dis_errD_real + aux_errD_real
        errD_real.backward()
        D_x = dis_output.data.mean()

        # compute the current classification accuracy
        accuracy = compute_acc(aux_output, aux_label)

        # train with fake: sample labels and embed them (one-hot) in the
        # first num_classes dimensions of the noise vector
        noise.data.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        label = np.random.randint(0, num_classes, batch_size)
        noise_ = np.random.normal(0, 1, (batch_size, nz))
        class_onehot = np.zeros((batch_size, num_classes))
        class_onehot[np.arange(batch_size), label] = 1
        noise_[np.arange(batch_size), :num_classes] = class_onehot[np.arange(batch_size)]
        noise_ = (torch.from_numpy(noise_))
        noise.data.copy_(noise_.view(batch_size, nz, 1, 1))
        aux_label.data.resize_(batch_size).copy_(torch.from_numpy(label))

        fake = netG(noise)
        dis_label.data.fill_(fake_label)
        dis_output, aux_output = netD(fake.detach())
        dis_errD_fake = dis_criterion(dis_output, dis_label)
        aux_errD_fake = aux_criterion(aux_output, aux_label)
        errD_fake = dis_errD_fake + aux_errD_fake
        errD_fake.backward()
        D_G_z1 = dis_output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        dis_label.data.fill_(real_label)  # fake labels are real for generator cost
        dis_output, aux_output = netD(fake)
        dis_errG = dis_criterion(dis_output, dis_label)
        aux_errG = aux_criterion(aux_output, aux_label)
        errG = dis_errG + aux_errG
        errG.backward()
        D_G_z2 = dis_output.data.mean()
        optimizerG.step()

        # compute the average loss
        curr_iter = epoch * len(dataloader) + i
        all_loss_G = avg_loss_G * curr_iter
        all_loss_D = avg_loss_D * curr_iter
        all_loss_A = avg_loss_A * curr_iter
        all_loss_G += errG.data[0]
        all_loss_D += errD.data[0]
        all_loss_A += accuracy
        avg_loss_G = all_loss_G / (curr_iter + 1)
        avg_loss_D = all_loss_D / (curr_iter + 1)
        avg_loss_A = all_loss_A / (curr_iter + 1)

        print('[%d/%d][%d/%d] Loss_D: %.4f (%.4f) Loss_G: %.4f (%.4f) D(x): %.4f D(G(z)): %.4f / %.4f Acc: %.4f (%.4f)'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.data[0], avg_loss_D, errG.data[0], avg_loss_G, D_x, D_G_z1, D_G_z2, accuracy, avg_loss_A))
        if i % 100 == 0:
            vutils.save_image(
                real_cpu, '%s/real_samples.png' % opt.outf)
            print('Label for eval = {}'.format(eval_label))
            fake = netG(eval_noise)
            vutils.save_image(
                fake.data,
                '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch)
            )

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class _netG(nn.Module):
    def __init__(self, ngpu, nz):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        self.nz = nz

        # first linear layer
        self.fc1 = nn.Linear(nz, 768)
        # Transposed Convolution 2
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(768, 384, 5, 2, 0, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(True),
        )
        # Transposed Convolution 3
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(384, 256, 5, 2, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        # Transposed Convolution 4
        self.tconv4 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 5, 2, 0, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        # Transposed Convolution 5
        self.tconv5 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 5, 2, 0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        # Transposed Convolution 6
        self.tconv6 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 8, 2, 0, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            input = input.view(-1, self.nz)
            fc1 = nn.parallel.data_parallel(self.fc1, input, range(self.ngpu))
            fc1 = fc1.view(-1, 768, 1, 1)
            tconv2 = nn.parallel.data_parallel(self.tconv2, fc1, range(self.ngpu))
            tconv3 = nn.parallel.data_parallel(self.tconv3, tconv2, range(self.ngpu))
            tconv4 = nn.parallel.data_parallel(self.tconv4, tconv3, range(self.ngpu))
            tconv5 = nn.parallel.data_parallel(self.tconv5, tconv4, range(self.ngpu))
            tconv6 = nn.parallel.data_parallel(self.tconv6, tconv5, range(self.ngpu))
            output = tconv6
        else:
            input = input.view(-1, self.nz)
            fc1 = self.fc1(input)
            fc1 = fc1.view(-1, 768, 1, 1)
            tconv2 = self.tconv2(fc1)
            tconv3 = self.tconv3(tconv2)
            tconv4 = self.tconv4(tconv3)
            tconv5 = self.tconv5(tconv4)
            tconv6 = self.tconv6(tconv5)
            output = tconv6
        return output


class _netD(nn.Module):
    def __init__(self, ngpu, num_classes=10):
        super(_netD, self).__init__()
        self.ngpu = ngpu

        # Convolution 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 0, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 4
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 5
        self.conv5 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 6
        self.conv6 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # discriminator fc
        self.fc_dis = nn.Linear(13*13*512, 1)
        # aux-classifier fc
        self.fc_aux = nn.Linear(13*13*512, num_classes)
        # log-softmax and sigmoid heads; NLLLoss in main.py expects log-probabilities
        self.softmax = nn.LogSoftmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            conv1 = nn.parallel.data_parallel(self.conv1, input, range(self.ngpu))
            conv2 = nn.parallel.data_parallel(self.conv2, conv1, range(self.ngpu))
            conv3 = nn.parallel.data_parallel(self.conv3, conv2, range(self.ngpu))
            conv4 = nn.parallel.data_parallel(self.conv4, conv3, range(self.ngpu))
            conv5 = nn.parallel.data_parallel(self.conv5, conv4, range(self.ngpu))
            conv6 = nn.parallel.data_parallel(self.conv6, conv5, range(self.ngpu))
            flat6 = conv6.view(-1, 13*13*512)
            fc_dis = nn.parallel.data_parallel(self.fc_dis, flat6, range(self.ngpu))
            fc_aux = nn.parallel.data_parallel(self.fc_aux, flat6, range(self.ngpu))
        else:
            conv1 = self.conv1(input)
            conv2 = self.conv2(conv1)
            conv3 = self.conv3(conv2)
            conv4 = self.conv4(conv3)
            conv5 = self.conv5(conv4)
            conv6 = self.conv6(conv5)
            flat6 = conv6.view(-1, 13*13*512)
            fc_dis = self.fc_dis(flat6)
            fc_aux = self.fc_aux(flat6)
        classes = self.softmax(fc_aux)
        realfake = self.sigmoid(fc_dis).view(-1, 1).squeeze(1)
        return realfake, classes


class _netG_CIFAR10(nn.Module):
    def __init__(self, ngpu, nz):
        super(_netG_CIFAR10, self).__init__()
        self.ngpu = ngpu
        self.nz = nz

        # first linear layer
        self.fc1 = nn.Linear(nz, 384)
        # Transposed Convolution 2
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(384, 192, 4, 1, 0, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        # Transposed Convolution 3
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(True),
        )
        # Transposed Convolution 4
        self.tconv4 = nn.Sequential(
            nn.ConvTranspose2d(96, 48, 4, 2, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(True),
        )
        # Transposed Convolution 5
        self.tconv5 = nn.Sequential(
            nn.ConvTranspose2d(48, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            input = input.view(-1, self.nz)
            fc1 = nn.parallel.data_parallel(self.fc1, input, range(self.ngpu))
            fc1 = fc1.view(-1, 384, 1, 1)
            tconv2 = nn.parallel.data_parallel(self.tconv2, fc1, range(self.ngpu))
            tconv3 = nn.parallel.data_parallel(self.tconv3, tconv2, range(self.ngpu))
            tconv4 = nn.parallel.data_parallel(self.tconv4, tconv3, range(self.ngpu))
            tconv5 = nn.parallel.data_parallel(self.tconv5, tconv4, range(self.ngpu))
            output = tconv5
        else:
            input = input.view(-1, self.nz)
            fc1 = self.fc1(input)
            fc1 = fc1.view(-1, 384, 1, 1)
            tconv2 = self.tconv2(fc1)
            tconv3 = self.tconv3(tconv2)
            tconv4 = self.tconv4(tconv3)
            tconv5 = self.tconv5(tconv4)
            output = tconv5
        return output


class _netD_CIFAR10(nn.Module):
    def __init__(self, ngpu, num_classes=10):
        super(_netD_CIFAR10, self).__init__()
        self.ngpu = ngpu

        # Convolution 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 4
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 5
        self.conv5 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # Convolution 6
        self.conv6 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5, inplace=False),
        )
        # discriminator fc
        self.fc_dis = nn.Linear(4*4*512, 1)
        # aux-classifier fc
        self.fc_aux = nn.Linear(4*4*512, num_classes)
        # log-softmax and sigmoid heads; NLLLoss in main.py expects log-probabilities
        self.softmax = nn.LogSoftmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            conv1 = nn.parallel.data_parallel(self.conv1, input, range(self.ngpu))
            conv2 = nn.parallel.data_parallel(self.conv2, conv1, range(self.ngpu))
            conv3 = nn.parallel.data_parallel(self.conv3, conv2, range(self.ngpu))
            conv4 = nn.parallel.data_parallel(self.conv4, conv3, range(self.ngpu))
            conv5 = nn.parallel.data_parallel(self.conv5, conv4, range(self.ngpu))
            conv6 = nn.parallel.data_parallel(self.conv6, conv5, range(self.ngpu))
            flat6 = conv6.view(-1, 4*4*512)
            fc_dis = nn.parallel.data_parallel(self.fc_dis, flat6, range(self.ngpu))
            fc_aux = nn.parallel.data_parallel(self.fc_aux, flat6, range(self.ngpu))
        else:
            conv1 = self.conv1(input)
            conv2 = self.conv2(conv1)
            conv3 = self.conv3(conv2)
            conv4 = self.conv4(conv3)
            conv5 = self.conv5(conv4)
            conv6 = self.conv6(conv5)
            flat6 = conv6.view(-1, 4*4*512)
            fc_dis = self.fc_dis(flat6)
            fc_aux = self.fc_aux(flat6)
        classes = self.softmax(fc_aux)
        realfake = self.sigmoid(fc_dis).view(-1, 1).squeeze(1)
        return realfake, classes
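

if __name__ == '__main__':
    # Illustrative shape check (not part of the training flow; assumes CPU and
    # ngpu=1): the CIFAR-10 generator should map (N, nz, 1, 1) noise to
    # (N, 3, 32, 32) images, and the discriminator should return a length-N
    # vector of real/fake scores plus an (N, num_classes) matrix of class
    # log-probabilities.
    from torch.autograd import Variable

    nz = 110
    netG = _netG_CIFAR10(ngpu=1, nz=nz)
    netD = _netD_CIFAR10(ngpu=1, num_classes=10)
    noise = Variable(torch.randn(4, nz, 1, 1))
    fake = netG(noise)
    realfake, classes = netD(fake)
    print(fake.size())      # (4, 3, 32, 32)
    print(realfake.size())  # (4,)
    print(classes.size())   # (4, 10)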

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# compute the current classification accuracy
def compute_acc(preds, labels):
    preds_ = preds.data.max(1)[1]
    correct = preds_.eq(labels.data).cpu().sum()
    acc = float(correct) / float(len(labels.data)) * 100.0
    return acc
--------------------------------------------------------------------------------