├── .gitignore
├── LICENSE
├── README.md
├── figs
├── architecture.png
├── fake_samples_epoch_470.png
└── fake_samples_epoch_499.png
├── folder.py
├── main.py
├── network.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs*
2 | *.pyc
3 | *.sw[po]
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Te-Lin Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Image Synthesis With Auxiliary Classifier GANs
2 |
3 | As part of the implementation series of [Joseph Lim's group at USC](http://csail.mit.edu/~lim), our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and publicly share them with concise reports. Please visit our [group GitHub site](https://github.com/gitlimlab) for other projects.
4 |
5 | This project is implemented by [Te-Lin Wu](https://github.com/telin0411), and the code has been reviewed by [Shao-Hua Sun](https://github.com/shaohua0116) before publication.
6 |
7 | ## Descriptions
8 | This project is a [PyTorch](http://pytorch.org) implementation of [Conditional Image Synthesis With Auxiliary Classifier GANs](https://arxiv.org/abs/1610.09585), published at ICML 2017. The paper proposes a simple extension of GANs that employs label conditioning to produce high-resolution, high-quality generated images.
9 |
10 | By adding an auxiliary classifier to the discriminator of a GAN, the discriminator produces not only a probability distribution over sources (real vs. fake) but also a probability distribution over the class labels. This simple modification to the standard DCGAN model does not change the architecture dramatically, yet it produces better results and helps stabilize adversarial training.
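
Concretely, the total discriminator loss in this repository is the sum of a source (real/fake) term and an auxiliary classification term (see main.py). Below is a minimal, self-contained sketch of that combined objective; the toy heads and tensor shapes are illustrative only, it uses current PyTorch idioms rather than the Variable-era API in this repository, and it applies `log_softmax`, which is the input `NLLLoss` expects:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, num_classes, feat = 4, 10, 32
features = torch.randn(batch, feat)        # stand-in for the discriminator's conv features
dis_head = nn.Linear(feat, 1)              # source head: probability the input is real
aux_head = nn.Linear(feat, num_classes)    # auxiliary head: distribution over class labels

realfake = torch.sigmoid(dis_head(features)).squeeze(1)
log_classes = F.log_softmax(aux_head(features), dim=1)

dis_label = torch.ones(batch)                        # 1 = real, 0 = fake
aux_label = torch.randint(0, num_classes, (batch,))  # ground-truth class indices

# total loss = BCE on the source head + NLL on the class head, as in main.py
loss = nn.BCELoss()(realfake, dis_label) + nn.NLLLoss()(log_classes, aux_label)
loss.backward()
```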
11 |
12 | The architecture is shown below, comparing several GAN variants.
13 |
14 | ![architecture](figs/architecture.png)
15 |
16 |
17 | Sample images generated from the ImageNet dataset.
18 |
19 |
20 |
21 |
22 | Sample images generated from the CIFAR-10 dataset.
23 |
24 |
25 |
26 |
27 | The implemented model can be trained on both [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and [ImageNet](http://www.image-net.org) datasets.
28 |
29 | Note that this implementation may differ from the original paper in details such as the model architecture, hyperparameters, and choice of optimizer, while preserving the main proposed idea.
30 |
31 | \*This code is still being developed and subject to change.
32 |
33 | ## Prerequisites
34 |
35 | - Python 2.7
36 | - [PyTorch](http://pytorch.org)
37 | - [SciPy](http://www.scipy.org/install.html)
38 | - [NumPy](http://www.numpy.org/)
39 | - [PIL](http://pillow.readthedocs.io/en/3.1.x/installation.html)
40 | - [imageio](https://imageio.github.io/)
41 |
42 | ## Usage
43 | Run the following command for details of each argument.
44 | ```bash
45 | $ python main.py -h
46 | ```
47 | Specify the path to the dataset with the `--dataroot` argument. For CIFAR-10, the code automatically checks whether the dataset has already been downloaded and downloads it for you if not. For ImageNet training, download the full dataset from the [official website](http://www.image-net.org) (this repository used the 2012 version) and point `--dataroot` to the train (or val) directory as the root directory. The expected directory layout is shown below.
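
For reference, the loader in folder.py expects one subdirectory per class under the root you pass; for ImageNet these are WordNet synset IDs. A hypothetical layout:

```
/data/path/to/imagenet/train/
├── n01440764/
│   ├── img_0001.JPEG
│   └── ...
├── n01443537/
│   └── ...
└── ...
```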
48 |
49 | Around line 80 of main.py, you can change the `classes_idx` argument to select other user-specified ImageNet classes, and adjust `num_classes` accordingly if it is not 10.
50 | ```python
51 | if opt.dataset == 'imagenet':
52 | # folder dataset
53 | dataset = ImageFolder(root=opt.dataroot,
54 | transform=transforms.Compose([
55 | transforms.Scale(opt.imageSize),
56 | transforms.CenterCrop(opt.imageSize),
57 | transforms.ToTensor(),
58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
59 | ]),
60 | classes_idx=(10,20))
61 | ```
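
Under the hood, `classes_idx` simply slices the alphabetically sorted list of class directories, so `(10, 20)` keeps the 11th through 20th classes. A simplified sketch of the logic in `find_classes` from folder.py:

```python
import os

def find_classes(dir, classes_idx=None):
    # sort class directory names, then keep only the requested slice
    classes = sorted(d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)))
    if classes_idx is not None:
        start, end = classes_idx   # e.g. (10, 20) keeps 10 classes
        classes = classes[start:end]
    return classes, {c: i for i, c in enumerate(classes)}
```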
62 |
63 | ### Train the models
64 | An example training command is shown below; during training, the code automatically writes generated test images to the `--outf` directory.
65 | ```bash
66 | $ python main.py --outf=/your/output/file/name --niter=500 --batchSize=100 --cuda --dataset=cifar10 --imageSize=32 --dataroot=/data/path/to/cifar10 --gpu_id=0
67 | ```
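
Checkpoints are saved to the `--outf` directory after every epoch as `netG_epoch_<epoch>.pth` and `netD_epoch_<epoch>.pth`, so training can be resumed by pointing `--netG` and `--netD` at a saved pair, for example (hypothetical paths):

```bash
$ python main.py --outf=/your/output/file/name --niter=500 --batchSize=100 --cuda --dataset=cifar10 --imageSize=32 --dataroot=/data/path/to/cifar10 --gpu_id=0 --netG=/your/output/file/name/netG_epoch_499.pth --netD=/your/output/file/name/netD_epoch_499.pth
```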
68 |
69 | ## Author
70 |
71 | Te-Lin Wu / [@telin0411](https://github.com/telin0411) @ [Joseph Lim's research lab](https://github.com/gitlimlab) @ USC
72 |
--------------------------------------------------------------------------------
/figs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/ACGAN-PyTorch/4bd0405a4bd90b07548d4b84b57e24767d7cbf65/figs/architecture.png
--------------------------------------------------------------------------------
/figs/fake_samples_epoch_470.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/ACGAN-PyTorch/4bd0405a4bd90b07548d4b84b57e24767d7cbf65/figs/fake_samples_epoch_470.png
--------------------------------------------------------------------------------
/figs/fake_samples_epoch_499.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/ACGAN-PyTorch/4bd0405a4bd90b07548d4b84b57e24767d7cbf65/figs/fake_samples_epoch_499.png
--------------------------------------------------------------------------------
/folder.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 |
7 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
8 |
9 |
10 | def is_image_file(filename):
11 | """Checks if a file is an image.
12 |
13 | Args:
14 | filename (string): path to a file
15 |
16 | Returns:
17 | bool: True if the filename ends with a known image extension
18 | """
19 | filename_lower = filename.lower()
20 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
21 |
22 |
23 | def find_classes(dir, classes_idx=None):
24 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
25 | classes.sort()
26 | if classes_idx is not None:
27 | assert isinstance(classes_idx, tuple)
28 | start, end = classes_idx
29 | classes = classes[start:end]
30 | class_to_idx = {classes[i]: i for i in range(len(classes))}
31 | return classes, class_to_idx
32 |
33 |
34 | def make_dataset(dir, class_to_idx):
35 | images = []
36 | dir = os.path.expanduser(dir)
37 | for target in sorted(os.listdir(dir)):
38 | if target not in class_to_idx:
39 | continue
40 | d = os.path.join(dir, target)
41 | if not os.path.isdir(d):
42 | continue
43 |
44 | for root, _, fnames in sorted(os.walk(d)):
45 | for fname in sorted(fnames):
46 | if is_image_file(fname):
47 | path = os.path.join(root, fname)
48 | item = (path, class_to_idx[target])
49 | images.append(item)
50 |
51 | return images
52 |
53 |
54 | def pil_loader(path):
55 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
56 | with open(path, 'rb') as f:
57 | with Image.open(f) as img:
58 | return img.convert('RGB')
59 |
60 |
61 | def accimage_loader(path):
62 | import accimage
63 | try:
64 | return accimage.Image(path)
65 | except IOError:
66 | # Potentially a decoding problem, fall back to PIL.Image
67 | return pil_loader(path)
68 |
69 |
70 | def default_loader(path):
71 | from torchvision import get_image_backend
72 | if get_image_backend() == 'accimage':
73 | return accimage_loader(path)
74 | else:
75 | return pil_loader(path)
76 |
77 |
78 | class ImageFolder(data.Dataset):
79 | """A generic data loader where the images are arranged in this way: ::
80 |
81 | root/dog/xxx.png
82 | root/dog/xxy.png
83 | root/dog/xxz.png
84 |
85 | root/cat/123.png
86 | root/cat/nsdf3.png
87 | root/cat/asd932_.png
88 |
89 | Args:
90 | root (string): Root directory path.
91 | transform (callable, optional): A function/transform that takes in an PIL image
92 | and returns a transformed version. E.g., ``transforms.RandomCrop``
93 | target_transform (callable, optional): A function/transform that takes in the
94 | target and transforms it.
95 | loader (callable, optional): A function to load an image given its path.
96 |
97 | Attributes:
98 | classes (list): List of the class names.
99 | class_to_idx (dict): Dict with items (class_name, class_index).
100 | imgs (list): List of (image path, class_index) tuples
101 | """
102 |
103 | def __init__(self, root, transform=None, target_transform=None,
104 | loader=default_loader, classes_idx=None):
105 | self.classes_idx = classes_idx
106 | classes, class_to_idx = find_classes(root, self.classes_idx)
107 | imgs = make_dataset(root, class_to_idx)
108 | if len(imgs) == 0:
109 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
110 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
111 |
112 | self.root = root
113 | self.imgs = imgs
114 | self.classes = classes
115 | self.class_to_idx = class_to_idx
116 | self.transform = transform
117 | self.target_transform = target_transform
118 | self.loader = loader
119 |
120 | def __getitem__(self, index):
121 | """
122 | Args:
123 | index (int): Index
124 |
125 | Returns:
126 | tuple: (image, target) where target is class_index of the target class.
127 | """
128 | path, target = self.imgs[index]
129 | img = self.loader(path)
130 | if self.transform is not None:
131 | img = self.transform(img)
132 | if self.target_transform is not None:
133 | target = self.target_transform(target)
134 |
135 | return img, target
136 |
137 | def __len__(self):
138 | return len(self.imgs)
139 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Code modified from PyTorch DCGAN examples: https://github.com/pytorch/examples/tree/master/dcgan
3 | """
4 | from __future__ import print_function
5 | import argparse
6 | import os
7 | import numpy as np
8 | import random
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.optim as optim
14 | import torch.utils.data
15 | import torchvision.datasets as dset
16 | import torchvision.transforms as transforms
17 | import torchvision.utils as vutils
18 | from torch.autograd import Variable
19 | from utils import weights_init, compute_acc
20 | from network import _netG, _netD, _netD_CIFAR10, _netG_CIFAR10
21 | from folder import ImageFolder
22 |
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--dataset', required=True, help='cifar10 | imagenet')
26 | parser.add_argument('--dataroot', required=True, help='path to dataset')
27 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
28 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
29 | parser.add_argument('--imageSize', type=int, default=128, help='the height / width of the input image to network')
30 | parser.add_argument('--nz', type=int, default=110, help='size of the latent z vector')
31 | parser.add_argument('--ngf', type=int, default=64)
32 | parser.add_argument('--ndf', type=int, default=64)
33 | parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
34 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
35 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
36 | parser.add_argument('--cuda', action='store_true', help='enables cuda')
37 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
38 | parser.add_argument('--netG', default='', help="path to netG (to continue training)")
39 | parser.add_argument('--netD', default='', help="path to netD (to continue training)")
40 | parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
41 | parser.add_argument('--manualSeed', type=int, help='manual seed')
42 | parser.add_argument('--num_classes', type=int, default=10, help='Number of classes for AC-GAN')
43 | parser.add_argument('--gpu_id', type=int, default=0, help='The ID of the specified GPU')
44 |
45 | opt = parser.parse_args()
46 | print(opt)
47 |
48 | # specify the gpu id if using only 1 gpu
49 | if opt.ngpu == 1:
50 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_id)
51 |
52 | try:
53 | os.makedirs(opt.outf)
54 | except OSError:
55 | pass
56 |
57 | if opt.manualSeed is None:
58 | opt.manualSeed = random.randint(1, 10000)
59 | print("Random Seed: ", opt.manualSeed)
60 | random.seed(opt.manualSeed)
61 | torch.manual_seed(opt.manualSeed)
62 | if opt.cuda:
63 | torch.cuda.manual_seed_all(opt.manualSeed)
64 |
65 | cudnn.benchmark = True
66 |
67 | if torch.cuda.is_available() and not opt.cuda:
68 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
69 |
70 | # dataset
71 | if opt.dataset == 'imagenet':
72 | # folder dataset
73 | dataset = ImageFolder(
74 | root=opt.dataroot,
75 | transform=transforms.Compose([
76 | transforms.Scale(opt.imageSize),
77 | transforms.CenterCrop(opt.imageSize),
78 | transforms.ToTensor(),
79 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
80 | ]),
81 | classes_idx=(10, 20)
82 | )
83 | elif opt.dataset == 'cifar10':
84 | dataset = dset.CIFAR10(
85 | root=opt.dataroot, download=True,
86 | transform=transforms.Compose([
87 | transforms.Scale(opt.imageSize),
88 | transforms.ToTensor(),
89 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
90 | ]))
91 | else:
92 | raise NotImplementedError("No such dataset {}".format(opt.dataset))
93 |
94 | assert dataset
95 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
96 | shuffle=True, num_workers=int(opt.workers))
97 |
98 | # some hyperparameters
99 | ngpu = int(opt.ngpu)
100 | nz = int(opt.nz)
101 | ngf = int(opt.ngf)
102 | ndf = int(opt.ndf)
103 | num_classes = int(opt.num_classes)
104 | nc = 3
105 |
106 | # Define the generator and initialize the weights
107 | if opt.dataset == 'imagenet':
108 | netG = _netG(ngpu, nz)
109 | else:
110 | netG = _netG_CIFAR10(ngpu, nz)
111 | netG.apply(weights_init)
112 | if opt.netG != '':
113 | netG.load_state_dict(torch.load(opt.netG))
114 | print(netG)
115 |
116 | # Define the discriminator and initialize the weights
117 | if opt.dataset == 'imagenet':
118 | netD = _netD(ngpu, num_classes)
119 | else:
120 | netD = _netD_CIFAR10(ngpu, num_classes)
121 | netD.apply(weights_init)
122 | if opt.netD != '':
123 | netD.load_state_dict(torch.load(opt.netD))
124 | print(netD)
125 |
126 | # loss functions
127 | dis_criterion = nn.BCELoss()
128 | aux_criterion = nn.NLLLoss()
129 |
130 | # tensor placeholders
131 | input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
132 | noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
133 | eval_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
134 | dis_label = torch.FloatTensor(opt.batchSize)
135 | aux_label = torch.LongTensor(opt.batchSize)
136 | real_label = 1
137 | fake_label = 0
138 |
139 | # if using cuda
140 | if opt.cuda:
141 | netD.cuda()
142 | netG.cuda()
143 | dis_criterion.cuda()
144 | aux_criterion.cuda()
145 | input, dis_label, aux_label = input.cuda(), dis_label.cuda(), aux_label.cuda()
146 | noise, eval_noise = noise.cuda(), eval_noise.cuda()
147 |
148 | # define variables
149 | input = Variable(input)
150 | noise = Variable(noise)
151 | eval_noise = Variable(eval_noise)
152 | dis_label = Variable(dis_label)
153 | aux_label = Variable(aux_label)
154 | # noise for evaluation: embed a random class label as a one-hot code in the first num_classes entries of z
155 | eval_noise_ = np.random.normal(0, 1, (opt.batchSize, nz))
156 | eval_label = np.random.randint(0, num_classes, opt.batchSize)
157 | eval_onehot = np.zeros((opt.batchSize, num_classes))
158 | eval_onehot[np.arange(opt.batchSize), eval_label] = 1
159 | eval_noise_[np.arange(opt.batchSize), :num_classes] = eval_onehot[np.arange(opt.batchSize)]
160 | eval_noise_ = (torch.from_numpy(eval_noise_))
161 | eval_noise.data.copy_(eval_noise_.view(opt.batchSize, nz, 1, 1))
162 |
163 | # setup optimizer
164 | optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
165 | optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
166 |
167 | avg_loss_D = 0.0
168 | avg_loss_G = 0.0
169 | avg_loss_A = 0.0
170 | for epoch in range(opt.niter):
171 | for i, data in enumerate(dataloader, 0):
172 | ############################
173 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
174 | ###########################
175 | # train with real
176 | netD.zero_grad()
177 | real_cpu, label = data
178 | batch_size = real_cpu.size(0)
179 | if opt.cuda:
180 | real_cpu = real_cpu.cuda()
181 | input.data.resize_as_(real_cpu).copy_(real_cpu)
182 | dis_label.data.resize_(batch_size).fill_(real_label)
183 | aux_label.data.resize_(batch_size).copy_(label)
184 | dis_output, aux_output = netD(input)
185 |
186 | dis_errD_real = dis_criterion(dis_output, dis_label)
187 | aux_errD_real = aux_criterion(aux_output, aux_label)
188 | errD_real = dis_errD_real + aux_errD_real
189 | errD_real.backward()
190 | D_x = dis_output.data.mean()
191 |
192 | # compute the current classification accuracy
193 | accuracy = compute_acc(aux_output, aux_label)
194 |
195 | # train with fake: sample class labels and embed them as one-hot codes in the first num_classes entries of the noise
196 | noise.data.resize_(batch_size, nz, 1, 1).normal_(0, 1)
197 | label = np.random.randint(0, num_classes, batch_size)
198 | noise_ = np.random.normal(0, 1, (batch_size, nz))
199 | class_onehot = np.zeros((batch_size, num_classes))
200 | class_onehot[np.arange(batch_size), label] = 1
201 | noise_[np.arange(batch_size), :num_classes] = class_onehot[np.arange(batch_size)]
202 | noise_ = (torch.from_numpy(noise_))
203 | noise.data.copy_(noise_.view(batch_size, nz, 1, 1))
204 | aux_label.data.resize_(batch_size).copy_(torch.from_numpy(label))
205 |
206 | fake = netG(noise)
207 | dis_label.data.fill_(fake_label)
208 | dis_output, aux_output = netD(fake.detach())
209 | dis_errD_fake = dis_criterion(dis_output, dis_label)
210 | aux_errD_fake = aux_criterion(aux_output, aux_label)
211 | errD_fake = dis_errD_fake + aux_errD_fake
212 | errD_fake.backward()
213 | D_G_z1 = dis_output.data.mean()
214 | errD = errD_real + errD_fake
215 | optimizerD.step()
216 |
217 | ############################
218 | # (2) Update G network: maximize log(D(G(z)))
219 | ###########################
220 | netG.zero_grad()
221 | dis_label.data.fill_(real_label) # fake labels are real for generator cost
222 | dis_output, aux_output = netD(fake)
223 | dis_errG = dis_criterion(dis_output, dis_label)
224 | aux_errG = aux_criterion(aux_output, aux_label)
225 | errG = dis_errG + aux_errG
226 | errG.backward()
227 | D_G_z2 = dis_output.data.mean()
228 | optimizerG.step()
229 |
230 | # compute the average loss
231 | curr_iter = epoch * len(dataloader) + i
232 | all_loss_G = avg_loss_G * curr_iter
233 | all_loss_D = avg_loss_D * curr_iter
234 | all_loss_A = avg_loss_A * curr_iter
235 | all_loss_G += errG.data[0]
236 | all_loss_D += errD.data[0]
237 | all_loss_A += accuracy
238 | avg_loss_G = all_loss_G / (curr_iter + 1)
239 | avg_loss_D = all_loss_D / (curr_iter + 1)
240 | avg_loss_A = all_loss_A / (curr_iter + 1)
241 |
242 | print('[%d/%d][%d/%d] Loss_D: %.4f (%.4f) Loss_G: %.4f (%.4f) D(x): %.4f D(G(z)): %.4f / %.4f Acc: %.4f (%.4f)'
243 | % (epoch, opt.niter, i, len(dataloader),
244 | errD.data[0], avg_loss_D, errG.data[0], avg_loss_G, D_x, D_G_z1, D_G_z2, accuracy, avg_loss_A))
245 | if i % 100 == 0:
246 | vutils.save_image(
247 | real_cpu, '%s/real_samples.png' % opt.outf)
248 | print('Label for eval = {}'.format(eval_label))
249 | fake = netG(eval_noise)
250 | vutils.save_image(
251 | fake.data,
252 | '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch)
253 | )
254 |
255 | # do checkpointing
256 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
257 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
258 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class _netG(nn.Module):
6 | def __init__(self, ngpu, nz):
7 | super(_netG, self).__init__()
8 | self.ngpu = ngpu
9 | self.nz = nz
10 |
11 | # first linear layer
12 | self.fc1 = nn.Linear(110, 768)
13 | # Transposed Convolution 2
14 | self.tconv2 = nn.Sequential(
15 | nn.ConvTranspose2d(768, 384, 5, 2, 0, bias=False),
16 | nn.BatchNorm2d(384),
17 | nn.ReLU(True),
18 | )
19 | # Transposed Convolution 3
20 | self.tconv3 = nn.Sequential(
21 | nn.ConvTranspose2d(384, 256, 5, 2, 0, bias=False),
22 | nn.BatchNorm2d(256),
23 | nn.ReLU(True),
24 | )
25 | # Transposed Convolution 4
26 | self.tconv4 = nn.Sequential(
27 | nn.ConvTranspose2d(256, 192, 5, 2, 0, bias=False),
28 | nn.BatchNorm2d(192),
29 | nn.ReLU(True),
30 | )
31 | # Transposed Convolution 5
32 | self.tconv5 = nn.Sequential(
33 | nn.ConvTranspose2d(192, 64, 5, 2, 0, bias=False),
34 | nn.BatchNorm2d(64),
35 | nn.ReLU(True),
36 | )
37 | # Transposed Convolution 6
38 | self.tconv6 = nn.Sequential(
39 | nn.ConvTranspose2d(64, 3, 8, 2, 0, bias=False),
40 | nn.Tanh(),
41 | )
42 |
43 | def forward(self, input):
44 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
45 | input = input.view(-1, self.nz)
46 | fc1 = nn.parallel.data_parallel(self.fc1, input, range(self.ngpu))
47 | fc1 = fc1.view(-1, 768, 1, 1)
48 | tconv2 = nn.parallel.data_parallel(self.tconv2, fc1, range(self.ngpu))
49 | tconv3 = nn.parallel.data_parallel(self.tconv3, tconv2, range(self.ngpu))
50 | tconv4 = nn.parallel.data_parallel(self.tconv4, tconv3, range(self.ngpu))
51 | tconv5 = nn.parallel.data_parallel(self.tconv5, tconv4, range(self.ngpu))
52 | tconv6 = nn.parallel.data_parallel(self.tconv6, tconv5, range(self.ngpu))
53 | output = tconv6
54 | else:
55 | input = input.view(-1, self.nz)
56 | fc1 = self.fc1(input)
57 | fc1 = fc1.view(-1, 768, 1, 1)
58 | tconv2 = self.tconv2(fc1)
59 | tconv3 = self.tconv3(tconv2)
60 | tconv4 = self.tconv4(tconv3)
61 | tconv5 = self.tconv5(tconv4)
62 | tconv6 = self.tconv6(tconv5)
63 | output = tconv6
64 | return output
65 |
66 |
67 | class _netD(nn.Module):
68 | def __init__(self, ngpu, num_classes=10):
69 | super(_netD, self).__init__()
70 | self.ngpu = ngpu
71 |
72 | # Convolution 1
73 | self.conv1 = nn.Sequential(
74 | nn.Conv2d(3, 16, 3, 2, 1, bias=False),
75 | nn.LeakyReLU(0.2, inplace=True),
76 | nn.Dropout(0.5, inplace=False),
77 | )
78 | # Convolution 2
79 | self.conv2 = nn.Sequential(
80 | nn.Conv2d(16, 32, 3, 1, 0, bias=False),
81 | nn.BatchNorm2d(32),
82 | nn.LeakyReLU(0.2, inplace=True),
83 | nn.Dropout(0.5, inplace=False),
84 | )
85 | # Convolution 3
86 | self.conv3 = nn.Sequential(
87 | nn.Conv2d(32, 64, 3, 2, 1, bias=False),
88 | nn.BatchNorm2d(64),
89 | nn.LeakyReLU(0.2, inplace=True),
90 | nn.Dropout(0.5, inplace=False),
91 | )
92 | # Convolution 4
93 | self.conv4 = nn.Sequential(
94 | nn.Conv2d(64, 128, 3, 1, 0, bias=False),
95 | nn.BatchNorm2d(128),
96 | nn.LeakyReLU(0.2, inplace=True),
97 | nn.Dropout(0.5, inplace=False),
98 | )
99 | # Convolution 5
100 | self.conv5 = nn.Sequential(
101 | nn.Conv2d(128, 256, 3, 2, 1, bias=False),
102 | nn.BatchNorm2d(256),
103 | nn.LeakyReLU(0.2, inplace=True),
104 | nn.Dropout(0.5, inplace=False),
105 | )
106 | # Convolution 6
107 | self.conv6 = nn.Sequential(
108 | nn.Conv2d(256, 512, 3, 1, 0, bias=False),
109 | nn.BatchNorm2d(512),
110 | nn.LeakyReLU(0.2, inplace=True),
111 | nn.Dropout(0.5, inplace=False),
112 | )
113 | # discriminator fc (13x13x512 feature map for 128x128 inputs, flattened)
114 | self.fc_dis = nn.Linear(13*13*512, 1)
115 | # aux-classifier fc
116 | self.fc_aux = nn.Linear(13*13*512, num_classes)
117 | # softmax and sigmoid
118 | self.softmax = nn.Softmax()
119 | self.sigmoid = nn.Sigmoid()
120 |
121 | def forward(self, input):
122 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
123 | conv1 = nn.parallel.data_parallel(self.conv1, input, range(self.ngpu))
124 | conv2 = nn.parallel.data_parallel(self.conv2, conv1, range(self.ngpu))
125 | conv3 = nn.parallel.data_parallel(self.conv3, conv2, range(self.ngpu))
126 | conv4 = nn.parallel.data_parallel(self.conv4, conv3, range(self.ngpu))
127 | conv5 = nn.parallel.data_parallel(self.conv5, conv4, range(self.ngpu))
128 | conv6 = nn.parallel.data_parallel(self.conv6, conv5, range(self.ngpu))
129 | flat6 = conv6.view(-1, 13*13*512)
130 | fc_dis = nn.parallel.data_parallel(self.fc_dis, flat6, range(self.ngpu))
131 | fc_aux = nn.parallel.data_parallel(self.fc_aux, flat6, range(self.ngpu))
132 | else:
133 | conv1 = self.conv1(input)
134 | conv2 = self.conv2(conv1)
135 | conv3 = self.conv3(conv2)
136 | conv4 = self.conv4(conv3)
137 | conv5 = self.conv5(conv4)
138 | conv6 = self.conv6(conv5)
139 | flat6 = conv6.view(-1, 13*13*512)
140 | fc_dis = self.fc_dis(flat6)
141 | fc_aux = self.fc_aux(flat6)
142 | classes = self.softmax(fc_aux)
143 | realfake = self.sigmoid(fc_dis).view(-1, 1).squeeze(1)
144 | return realfake, classes
145 |
146 |
147 | class _netG_CIFAR10(nn.Module):
148 | def __init__(self, ngpu, nz):
149 | super(_netG_CIFAR10, self).__init__()
150 | self.ngpu = ngpu
151 | self.nz = nz
152 |
153 | # first linear layer
154 | self.fc1 = nn.Linear(110, 384)
155 | # Transposed Convolution 2
156 | self.tconv2 = nn.Sequential(
157 | nn.ConvTranspose2d(384, 192, 4, 1, 0, bias=False),
158 | nn.BatchNorm2d(192),
159 | nn.ReLU(True),
160 | )
161 | # Transposed Convolution 3
162 | self.tconv3 = nn.Sequential(
163 | nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
164 | nn.BatchNorm2d(96),
165 | nn.ReLU(True),
166 | )
167 | # Transposed Convolution 4
168 | self.tconv4 = nn.Sequential(
169 | nn.ConvTranspose2d(96, 48, 4, 2, 1, bias=False),
170 | nn.BatchNorm2d(48),
171 | nn.ReLU(True),
172 | )
173 | # Transposed Convolution 5
174 | self.tconv5 = nn.Sequential(
175 | nn.ConvTranspose2d(48, 3, 4, 2, 1, bias=False),
176 | nn.Tanh(),
177 | )
178 |
179 | def forward(self, input):
180 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
181 | input = input.view(-1, self.nz)
182 | fc1 = nn.parallel.data_parallel(self.fc1, input, range(self.ngpu))
183 | fc1 = fc1.view(-1, 384, 1, 1)
184 | tconv2 = nn.parallel.data_parallel(self.tconv2, fc1, range(self.ngpu))
185 | tconv3 = nn.parallel.data_parallel(self.tconv3, tconv2, range(self.ngpu))
186 | tconv4 = nn.parallel.data_parallel(self.tconv4, tconv3, range(self.ngpu))
187 | tconv5 = nn.parallel.data_parallel(self.tconv5, tconv4, range(self.ngpu))
188 | output = tconv5
189 | else:
190 | input = input.view(-1, self.nz)
191 | fc1 = self.fc1(input)
192 | fc1 = fc1.view(-1, 384, 1, 1)
193 | tconv2 = self.tconv2(fc1)
194 | tconv3 = self.tconv3(tconv2)
195 | tconv4 = self.tconv4(tconv3)
196 | tconv5 = self.tconv5(tconv4)
197 | output = tconv5
198 | return output
199 |
200 |
201 | class _netD_CIFAR10(nn.Module):
202 | def __init__(self, ngpu, num_classes=10):
203 | super(_netD_CIFAR10, self).__init__()
204 | self.ngpu = ngpu
205 |
206 | # Convolution 1
207 | self.conv1 = nn.Sequential(
208 | nn.Conv2d(3, 16, 3, 2, 1, bias=False),
209 | nn.LeakyReLU(0.2, inplace=True),
210 | nn.Dropout(0.5, inplace=False),
211 | )
212 | # Convolution 2
213 | self.conv2 = nn.Sequential(
214 | nn.Conv2d(16, 32, 3, 1, 1, bias=False),
215 | nn.BatchNorm2d(32),
216 | nn.LeakyReLU(0.2, inplace=True),
217 | nn.Dropout(0.5, inplace=False),
218 | )
219 | # Convolution 3
220 | self.conv3 = nn.Sequential(
221 | nn.Conv2d(32, 64, 3, 2, 1, bias=False),
222 | nn.BatchNorm2d(64),
223 | nn.LeakyReLU(0.2, inplace=True),
224 | nn.Dropout(0.5, inplace=False),
225 | )
226 | # Convolution 4
227 | self.conv4 = nn.Sequential(
228 | nn.Conv2d(64, 128, 3, 1, 1, bias=False),
229 | nn.BatchNorm2d(128),
230 | nn.LeakyReLU(0.2, inplace=True),
231 | nn.Dropout(0.5, inplace=False),
232 | )
233 | # Convolution 5
234 | self.conv5 = nn.Sequential(
235 | nn.Conv2d(128, 256, 3, 2, 1, bias=False),
236 | nn.BatchNorm2d(256),
237 | nn.LeakyReLU(0.2, inplace=True),
238 | nn.Dropout(0.5, inplace=False),
239 | )
240 | # Convolution 6
241 | self.conv6 = nn.Sequential(
242 | nn.Conv2d(256, 512, 3, 1, 1, bias=False),
243 | nn.BatchNorm2d(512),
244 | nn.LeakyReLU(0.2, inplace=True),
245 | nn.Dropout(0.5, inplace=False),
246 | )
247 | # discriminator fc (4x4x512 feature map for 32x32 inputs, flattened)
248 | self.fc_dis = nn.Linear(4*4*512, 1)
249 | # aux-classifier fc
250 | self.fc_aux = nn.Linear(4*4*512, num_classes)
251 | # softmax and sigmoid
252 | self.softmax = nn.Softmax()
253 | self.sigmoid = nn.Sigmoid()
254 |
255 | def forward(self, input):
256 | if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
257 | conv1 = nn.parallel.data_parallel(self.conv1, input, range(self.ngpu))
258 | conv2 = nn.parallel.data_parallel(self.conv2, conv1, range(self.ngpu))
259 | conv3 = nn.parallel.data_parallel(self.conv3, conv2, range(self.ngpu))
260 | conv4 = nn.parallel.data_parallel(self.conv4, conv3, range(self.ngpu))
261 | conv5 = nn.parallel.data_parallel(self.conv5, conv4, range(self.ngpu))
262 | conv6 = nn.parallel.data_parallel(self.conv6, conv5, range(self.ngpu))
263 | flat6 = conv6.view(-1, 4*4*512)
264 | fc_dis = nn.parallel.data_parallel(self.fc_dis, flat6, range(self.ngpu))
265 | fc_aux = nn.parallel.data_parallel(self.fc_aux, flat6, range(self.ngpu))
266 | else:
267 | conv1 = self.conv1(input)
268 | conv2 = self.conv2(conv1)
269 | conv3 = self.conv3(conv2)
270 | conv4 = self.conv4(conv3)
271 | conv5 = self.conv5(conv4)
272 | conv6 = self.conv6(conv5)
273 | flat6 = conv6.view(-1, 4*4*512)
274 | fc_dis = self.fc_dis(flat6)
275 | fc_aux = self.fc_aux(flat6)
276 | classes = self.softmax(fc_aux)
277 | realfake = self.sigmoid(fc_dis).view(-1, 1).squeeze(1)
278 | return realfake, classes
279 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # custom weights initialization called on netG and netD
2 | def weights_init(m):
3 | classname = m.__class__.__name__
4 | if classname.find('Conv') != -1:
5 | m.weight.data.normal_(0.0, 0.02)
6 | elif classname.find('BatchNorm') != -1:
7 | m.weight.data.normal_(1.0, 0.02)
8 | m.bias.data.fill_(0)
9 |
10 | # compute the current classification accuracy
11 | def compute_acc(preds, labels):
12 | correct = 0
13 | preds_ = preds.data.max(1)[1]
14 | correct = preds_.eq(labels.data).cpu().sum()
15 | acc = float(correct) / float(len(labels.data)) * 100.0
16 | return acc
17 |
--------------------------------------------------------------------------------