├── LICENSE.md
├── README.md
├── example.png
├── fake_samples_epoch_300.png
└── improved_GAN.py

/LICENSE.md:
--------------------------------------------------------------------------------
Copyright (c) 2017 Yiliang Chen.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# improved_gan_pytorch
PyTorch implementation of a semi-supervised DCGAN based on "[Improved Techniques for Training GANs](http://arxiv.org/abs/1606.03498)".

Feature matching and the semi-supervised GAN objective have been reimplemented.

So far, the other improved techniques from the paper have not been added.

# Prerequisites
PyTorch 0.2.0_3 (the legacy 0.x `Variable` API) and Python 2.7

# Usage
Run: `python improved_GAN.py`

In this example the classifier is trained on CIFAR10, and each training batch is split so that

labeled input : unlabeled input : generated fake input = 1 : 1 : 1

The settings can be changed by following the comments in the script.

P.S.

For the generator loss, it is also equal to -loss_unlabeled_fake + loss_feature_matching.

For the labeled loss, it is also equal to -loss_target + log_sum_exp(before_softmax_labeled_output).
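For reference, here is a sketch of how those terms look written out. The notation is introduced here rather than taken from the original README: l_k(x) are the K = 10 class logits (`before_softmax_output`), Z(x) = sum_k exp(l_k(x)), the implicit "fake" logit is fixed at 0 so that D(x) = Z(x) / (Z(x) + 1), and f(.) is the intermediate discriminator feature used for matching.

```latex
% Discriminator loss (loss_label + loss_unl_real + loss_unl_fake in improved_GAN.py)
L_D = -\,\mathbb{E}_{(x,y)\sim \text{labeled}} \log p(y \mid x)
      \;-\; \mathbb{E}_{x\sim \text{unlabeled}} \Big[ \operatorname{LSE}(l(x)) - \operatorname{softplus}\big(\operatorname{LSE}(l(x))\big) \Big]
      \;+\; \mathbb{E}_{z} \, \operatorname{softplus}\big(\operatorname{LSE}(l(G(z)))\big)

% Generator loss (feature matching, criterionG = nn.MSELoss() in the code)
L_G = \Big\| \, \mathbb{E}_{x}\, f(x) \;-\; \mathbb{E}_{z}\, f(G(z)) \, \Big\|^2
```

Here LSE is the log-sum-exp over the class logits, so log D(x) = LSE - softplus(LSE) and log(1 - D(x)) = -softplus(LSE); the generator term is computed up to the averaging that `nn.MSELoss` applies.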
to reimplement "[Bad GAN](https://arxiv.org/abs/1705.09783)" paper 31 | 32 | # Semi-supervised + Feature matching CIFAR10 Classification 33 | 34 | ![image](https://github.com/eli5168/improved_gan_pytorch/blob/master/example.png) 35 | 36 | ![image](https://github.com/eli5168/improved_gan_pytorch/blob/master/fake_samples_epoch_300.png) 37 | 38 | 300th epoch 39 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eli5168/improved_gan_pytorch/ccb018b77274ea21d7e975087f5b52287626a3d1/example.png -------------------------------------------------------------------------------- /fake_samples_epoch_300.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eli5168/improved_gan_pytorch/ccb018b77274ea21d7e975087f5b52287626a3d1/fake_samples_epoch_300.png -------------------------------------------------------------------------------- /improved_GAN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | import torchvision.datasets as dset 11 | import torch.nn.functional as F 12 | #import torchvision.transforms as transforms 13 | import torchvision.utils as vutils 14 | from torchvision import datasets, transforms 15 | from torch.autograd import Variable 16 | 17 | batch_size = 200 #training batch_size 18 | momentum = 0.5 #adam parameter 19 | epochs = 500 # number of epochs (training) 20 | lr = 0.0001 #learning rate 21 | imageSize = 32 # resize input image to XX 22 | nz = 100 #size of the latent z vector 23 | ngf = 32 #number of G output filters 24 | ndf = 32 #number of D output filters 25 | nc = 3 # numbel of channel 26 | outf = './fake' #folder to output images and model checkpoints 27 | log_interval = 1000 28 | test_interval = 5000 29 | continue_netD = '' #"path to netD (to continue training" 30 | continue_netG = '' #"path to netG (to continue training" 31 | cudnn.benchmark = True 32 | fsave = open('accuracy.txt','w') 33 | 34 | if os.path.isdir('./fake'): 35 | pass 36 | else: 37 | os.mkdir('./fake') 38 | 39 | dataloader = torch.utils.data.DataLoader( 40 | datasets.CIFAR10('../data', train=True, download=True, 41 | transform=transforms.Compose([ 42 | transforms.Scale(imageSize), 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 45 | ])) 46 | ,batch_size=batch_size, shuffle=True, num_workers=4) 47 | 48 | testloader = torch.utils.data.DataLoader( 49 | datasets.CIFAR10('../data', train=False, download=True, 50 | transform=transforms.Compose([ 51 | transforms.Scale(imageSize), 52 | transforms.ToTensor(), 53 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 54 | ])) 55 | ,batch_size=batch_size, shuffle=False, num_workers=4) 56 | 57 | 58 | # custom weights initialization called on netG and netD 59 | def weights_init(m): 60 | classname = m.__class__.__name__ 61 | if classname.find('Conv') != -1: 62 | m.weight.data.normal_(0.0, 0.02) 63 | elif classname.find('BatchNorm') != -1: 64 | m.weight.data.normal_(1.0, 0.02) 65 | m.bias.data.fill_(0) 66 | 67 | 68 | class _netG(nn.Module): 69 | def __init__(self): 70 | super(_netG, self).__init__() 71 | self.main = nn.Sequential( 72 | # input is 

class _netG(nn.Module):
    def __init__(self):
        super(_netG, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 2, 1, bias=False),  # 1x1 -> 2x2: (1-1)*2 - 2*1 + 4 = 2
            #nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 2 x 2
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),  # 2x2 -> 4x4: (2-1)*2 - 2*1 + 4 = 4
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 16 x 16
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 32 x 32
        )

    def forward(self, input):
        output = self.main(input)
        return output


netG = _netG()
#netG.apply(weights_init)


class _netD(nn.Module):
    def __init__(self):
        super(_netD, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 32 x 32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 16 x 16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 8 x 8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            # state size. (ndf*8) x 2 x 2 = 1024 features
        )
        self.main2 = nn.Sequential(
            nn.Linear(1024, 10),   # 10-way class logits for CIFAR10
        )
        self.main3 = nn.Sequential(
            nn.Sigmoid()           # unused in the semi-supervised setting
        )

    def forward(self, input, matching=False):
        output = self.main(input)
        feature = output.view(-1, 1024)
        output = self.main2(feature)
        #output = self.main3(output)
        if matching:
            return feature, output   # intermediate features + class logits, for feature matching
        else:
            return output            # batch_size x 10 class logits (before softmax)


netD = _netD()
#netD.apply(weights_init)
# if continue_netD != '':
#     netD.load_state_dict(torch.load(continue_netD))
# print(netD)


def to_scalar(var):
    # returns a python float
    return var.view(-1).data.tolist()[0]


def argmax(vec):
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return to_scalar(idx)


# log_sum_exp over the class logits, shifted by a constant for numerical stability
def LSE(before_softmax_output):
    # exp = torch.exp(before_softmax_output)
    # sum_exp = torch.sum(exp, 1)
    # log_sum_exp = torch.log(sum_exp)
    # return log_sum_exp
    vec = before_softmax_output
    max_score = vec[0, argmax(vec)]   # max logit of the first row, used as the shift constant
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    output = max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast), 1))
    return output
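
# --- Illustrative sanity check (not part of the original script) ----------------
# LSE() above is just a numerically shifted log(sum(exp(.), 1)); for reasonably
# scaled logits it should match the naive computation below up to float error.
def naive_lse(before_softmax_output):
    return torch.log(torch.sum(torch.exp(before_softmax_output), 1))
# Hypothetical usage:
#   v = Variable(torch.randn(4, 10))
#   print(LSE(v) - naive_lse(v))   # should be close to zero
# --------------------------------------------------------------------------------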

def test():
    netD.eval()
    test_loss = 0
    correct = 0
    val_len = 10000
    for data, target in testloader:
        data, target = data.cuda(), torch.LongTensor(target).cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = netD(data)
        test_loss += F.cross_entropy(output, target, size_average=False).data[0]  # sum up batch loss
        #test_loss += torch.nn.MultiLabelSoftMarginLoss(output, target, size_average=False).data[0]
        pred = output.data.max(1, keepdim=True)[1]  # index of the max logit
        correct += pred.eq(target.data.view_as(pred)).cuda().sum()
    test_loss /= len(testloader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))
    print >> fsave, '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(testloader.dataset), 100. * correct / len(testloader.dataset))


criterionD = nn.CrossEntropyLoss()   # multi-class cross-entropy over the 10 CIFAR10 classes
criterionG = nn.MSELoss()            # feature matching loss for the generator
input = torch.FloatTensor(batch_size, 3, imageSize, imageSize)
input_label = torch.FloatTensor(batch_size)
noise = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)  # normal_(mean=0, std=1)
label = torch.FloatTensor(batch_size)


netD.cuda()
netG.cuda()
criterionD.cuda()
criterionG.cuda()
input, label = input.cuda(), label.cuda()
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

fixed_noise = Variable(fixed_noise)  # a fixed noise batch, only used to visualize generator progress

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(momentum, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(momentum, 0.999))

# dataloader yields (data, target) batches
for epoch in range(1, epochs + 1):
    for i, data in enumerate(dataloader):
        ############################
        # (1) Update D network: cross-entropy on the labeled data, plus the
        #     unsupervised terms -log D(x) (real) and -log(1 - D(G(z))) (fake)
        ############################
        # train with real
        # the discriminator/classifier loss has three parts
        netD.zero_grad()  # reset the discriminator gradients
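        # Added explanation (following the Improved GAN paper): with K = 10 class
        # logits and the implicit "fake" logit fixed at 0, the probability that x
        # is real is D(x) = Z(x) / (Z(x) + 1), where Z(x) = sum_k exp(logit_k(x)).
        # Hence  log D(x)      =  LSE(logits) - softplus(LSE(logits))
        # and    log(1 - D(x)) = -softplus(LSE(logits)),
        # which is exactly how loss_unl_real and loss_unl_fake are written below;
        # loss_label is the ordinary cross-entropy on the labeled half of the batch.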
        real_data, real_label = data
        batch_size = real_data.size(0)
        real_data = real_data.cuda()
        input.resize_as_(real_data).copy_(real_data)
        inputv = Variable(input)
        real_labelv = Variable(real_label)
        input_labelv = Variable(input_label)
        label_input = inputv[:100]      # first half of the batch is treated as labeled (100 = batch_size/2)
        unlabel_input = inputv[100:]    # second half is treated as unlabeled
        l_label = real_label[:100]
        l_labelv = Variable(l_label).cuda()
        l_output = netD(label_input)
        loss_label = criterionD(l_output, l_labelv)
        unl_output = netD(unlabel_input)
        loss_unl_real = -torch.mean(LSE(unl_output), 0) + torch.mean(F.softplus(LSE(unl_output), 1), 0)
        # train with fake
        noise.resize_(batch_size / 2, nz, 1, 1).normal_(0, 1)   # batch_size/2 fake images (Python 2 integer division)
        noisev = Variable(noise)
        fake = netG(noisev)
        unl_output = netD(fake.detach())   # detach so the D update does not backpropagate into G
        loss_unl_fake = torch.mean(F.softplus(LSE(unl_output), 1), 0)
        loss_D = loss_label + loss_unl_real + loss_unl_fake
        loss_D.backward()   # because of detach(), this leaves netG's gradients untouched
        optimizerD.step()

        ############################
        # (2) Update G network: feature matching
        #     (match the mean intermediate D features of real and fake batches)
        ############################
        netG.zero_grad()
        #labelv = Variable(label.fill_(real_label))  # fake labels are real for the generator cost
        ####### feature matching ########
        feature_real, _ = netD(inputv.detach(), matching=True)
        feature_fake, output = netD(fake, matching=True)
        feature_real = torch.mean(feature_real, 0)
        feature_fake = torch.mean(feature_fake, 0)
        loss_G = criterionG(feature_fake, feature_real.detach())
        ####### feature matching ########
        loss_G.backward()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_label: %.4f Loss_unlabel_real: %.4f Loss_fake: %.4f Loss_D: %.4f Loss_G: %.4f'
              % (epoch, epochs, i, len(dataloader), loss_label.data[0], loss_unl_real.data[0],
                 loss_unl_fake.data[0], loss_D.data[0], loss_G.data[0]))
        if i % log_interval == 0:
            vutils.save_image(real_data, '%s/real_samples.png' % outf, normalize=True)
            fake = netG(fixed_noise)   # generate from the fixed noise, just for visualization
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % (outf, epoch), normalize=True)  # .data pulls the tensor out of the Variable
        if i % test_interval == 0:
            test()
            netD.train()   # test() switches netD to eval mode; switch back for training

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (outf, epoch))

fsave.close()
--------------------------------------------------------------------------------