├── main.py
├── models.py
├── readme.md
└── utils.py

/main.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.utils as vutils

import models
import utils

train_loader = utils.load_data_CIFAR10()

if not os.path.exists('./result'):
    os.mkdir('result/')

if not os.path.exists('./model'):
    os.mkdir('model/')

netG = models.get_netG()
netD1 = models.get_netD()
netD2 = models.get_netD()

# set up optimizers
optimizerD1 = torch.optim.Adam(netD1.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD2 = torch.optim.Adam(netD2.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

criterion_log = utils.Log_loss()        # +/- sum(log x)
criterion_itself = utils.Itself_loss()  # +/- sum(x)

input = torch.FloatTensor(64, 3, 64, 64)  # resized to match each batch below
noise = torch.FloatTensor(64, 100, 1, 1)
fixed_noise = Variable(torch.FloatTensor(64, 100, 1, 1).normal_(0, 1))

use_cuda = torch.cuda.is_available()
if use_cuda:
    criterion_log, criterion_itself = criterion_log.cuda(), criterion_itself.cuda()
    input = input.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

for epoch in range(200):
    for i, data in enumerate(train_loader):
        real_cpu, _ = data
        batch_size = real_cpu.size(0)

        ######################################
        # train D1 and D2
        ######################################
        netD1.zero_grad()
        netD2.zero_grad()

        # train with real
        if use_cuda:
            real_cpu = real_cpu.cuda()
        input.resize_as_(real_cpu).copy_(real_cpu)
        inputv = Variable(input)

        # D1 sees real as real: minimize -0.2 * log D1(x)
        output = netD1(inputv)
        errD1_real = 0.2 * criterion_log(output)
        errD1_real.backward()

        # D2 sees real as fake: minimize D2(x)
        output = netD2(inputv)
        errD2_real = criterion_itself(output, False)
        errD2_real.backward()

        # train with fake
        noise.resize_(batch_size, 100, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = netG(noisev)

        # D1 sees fake as fake: minimize D1(G(z))
        output = netD1(fake.detach())
        errD1_fake = criterion_itself(output, False)
        errD1_fake.backward()

        # D2 sees fake as real: minimize -0.1 * log D2(G(z))
        output = netD2(fake.detach())
        errD2_fake = 0.1 * criterion_log(output)
        errD2_fake.backward()

        optimizerD1.step()
        optimizerD2.step()

        ######################################
        # train G
        ######################################
        netG.zero_grad()

        # G: minimize -D1(G(z)), to make D1 see fake as real
        output = netD1(fake)
        errG1 = criterion_itself(output)

        # G: minimize 0.1 * log D2(G(z)), to make D2 see fake as fake
        output = netD2(fake)
        errG2 = criterion_log(output, False)

        errG = errG2 * 0.1 + errG1
        errG.backward()
        optimizerG.step()

        if (i + 1) % 200 == 0:
            print("step %d: errG1 = %.4f, errG2 = %.4f" % (i + 1, errG1.item(), errG2.item() * 0.1))
            fake = netG(fixed_noise)
            # .cpu() is a no-op on CPU tensors, so one branch covers both cases
            vutils.save_image(fake.cpu().data,
                              '%s/fake_samples_epoch_%s.png' % ('result', str(epoch) + "_" + str(i + 1)),
                              normalize=True)

    print("%s epoch finished" % (str(epoch)))
    print("-----------------------------------------------------------------\n")

    fake = netG(fixed_noise)
    vutils.save_image(fake.cpu().data,
                      '%s/fake_samples_epoch_%s.png' % ('result', str(epoch) + "_" + str(i + 1)),
                      normalize=True)
    torch.save(netG.state_dict(), '%s/netG.pth' % ('model'))
    torch.save(netD1.state_dict(), '%s/netD1.pth' % ('model'))
    torch.save(netD2.state_dict(), '%s/netD2.pth' % ('model'))
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

nz = 100   # size of the latent z vector
nc = 3     # number of image channels
ngf = 64   # generator feature-map width
ndf = 64   # discriminator feature-map width

class _netG(nn.Module):
    def __init__(self):
        super(_netG, self).__init__()
        self.main = nn.Sequential(
            # input: Z, nz x 1 x 1
            nn.ConvTranspose2d(nz, ngf * 8, 2, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            # (ngf * 8) x 2 x 2
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # (ngf * 4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # (ngf * 2) x 8 x 8
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # ngf x 16 x 16
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # output: nc x 32 x 32
        )

    def forward(self, input):
        output = self.main(input)
        return output

class _netD(nn.Module):
    def __init__(self):
        super(_netD, self).__init__()
        self.main = nn.Sequential(
            # nc x 32 x 32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # ndf x 16 x 16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf * 2) x 8 x 8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf * 4) x 4 x 4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf * 8) x 2 x 2
            nn.Conv2d(ndf * 8, 1, 2, 1, 0, bias=False),
            # Softplus keeps the output positive, which the log-based
            # D2GAN losses require (unlike a vanilla GAN's sigmoid)
            nn.Softplus()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1).squeeze(1)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:         # Conv weight init
        m.weight.data.normal_(0.0, 0.01)
    elif classname.find('BatchNorm') != -1:  # BatchNorm weight init
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_netG():
    use_cuda = torch.cuda.is_available()
    netG = _netG()
    netG.apply(weights_init)
    if use_cuda:
        print("USE CUDA")
        netG.cuda()
    return netG

def get_netD():
    use_cuda = torch.cuda.is_available()
    netD = _netD()
    netD.apply(weights_init)
    if use_cuda:
        print("USE CUDA")
        netD.cuda()
    return netD
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Implementation of Dual-Discriminator GAN

This is a PyTorch implementation of D2GAN (Dual-Discriminator GAN), [paper](https://arxiv.org/abs/1709.03831)

![image](https://user-images.githubusercontent.com/25279765/35868536-1d145ac0-0ba0-11e8-8a88-87783989490a.png)

It uses two discriminators: one scores real data as real, while the other scores real data as fake.

According to the paper, the loss function differs from that of an ordinary GAN, so I implemented custom loss modules (`Log_loss` and `Itself_loss` in `utils.py`). Training can be unstable (e.g., the loss may explode or diverge to inf/-inf).
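As a rough sketch of how the two modules realize the paper's objective: `Log_loss` returns a (negated) sum of logs and `Itself_loss` a (negated) plain sum. The 0.2 and 0.1 weights below are the ones hard-coded in `main.py`; the random tensors are only stand-ins for discriminator outputs, kept positive the way the `Softplus` output layer in `models.py` guarantees.

````python
import torch
from utils import Log_loss, Itself_loss

criterion_log = Log_loss()        # returns -sum(log x) by default (negation=True)
criterion_itself = Itself_loss()  # returns -sum(x) by default (negation=True)

# Stand-ins for discriminator outputs on a batch of 64.
d1_real, d1_fake = torch.rand(64) + 0.1, torch.rand(64) + 0.1
d2_real, d2_fake = torch.rand(64) + 0.1, torch.rand(64) + 0.1

# D1: minimize -0.2*sum(log D1(x)) + sum(D1(G(z)))  (real as real, fake as fake)
loss_d1 = 0.2 * criterion_log(d1_real) + criterion_itself(d1_fake, negation=False)

# D2: minimize sum(D2(x)) - 0.1*sum(log D2(G(z)))   (real as fake, fake as real)
loss_d2 = criterion_itself(d2_real, negation=False) + 0.1 * criterion_log(d2_fake)

# G: minimize -sum(D1(G(z))) + 0.1*sum(log D2(G(z)))
loss_g = criterion_itself(d1_fake) + 0.1 * criterion_log(d2_fake, negation=False)
````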
Results are shown below. I used two datasets: STL10 and CIFAR10.

![fake_samples_epoch_61_600](https://user-images.githubusercontent.com/25279765/35868602-5302ee76-0ba0-11e8-9e64-4d46d34f1030.png)
> result on 64x64 STL10, generated by the model

![fake_samples_epoch_46_200](https://user-images.githubusercontent.com/25279765/35868662-818caf02-0ba0-11e8-8f13-10acf05277e7.png)
> result on 64x64 CIFAR10, generated by the model

In the paper, they used 32x32 images, so I adapted DCGAN's architecture to that resolution, changing the kernel size of G's first layer and D's last layer from 4 to 2 (see `models.py`; the change is sketched below).
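For reference, a minimal sketch of that change. The 64x64 layer is the standard DCGAN first generator layer; only the 32x32 version is what `models.py` actually contains:

````python
import torch.nn as nn

nz, ngf = 100, 64

# Standard DCGAN first generator layer for 64x64 output: z -> 4x4 feature map.
first_64 = nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False)

# This repo's first layer for 32x32 output: z -> 2x2 feature map (as in models.py);
# the four stride-2 upsampling layers then give 2 -> 4 -> 8 -> 16 -> 32.
first_32 = nn.ConvTranspose2d(nz, ngf * 8, 2, 1, 0, bias=False)
````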
Its results are shown below.

![fake_samples_epoch_42_400](https://user-images.githubusercontent.com/25279765/35869139-bbc53904-0ba1-11e8-9c36-3fe783512ec1.png)
> result on 32x32 CIFAR10, generated by the model

![image](https://user-images.githubusercontent.com/25279765/35869225-ecbb20d2-0ba1-11e8-9bfe-6cb263897e13.png)
> result on 32x32 CIFAR10, from the paper

![fake_samples_epoch_30_200](https://user-images.githubusercontent.com/25279765/35879666-216ffd94-0bbf-11e8-96c4-b7c06c2000ee.png)
> result on 32x32 STL10, generated by the model

The model shows much better results on 32x32 data, and it also converges much faster than an ordinary GAN.

## Usage

On your console/terminal, run

````python
python main.py
````
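There is no argument parser yet (see TODO below), so switching datasets means editing `main.py` directly. `utils.py` already ships an STL10 loader, so the change is one line:

````python
# in main.py: train on STL10 (train+unlabeled split) instead of CIFAR10
train_loader = utils.load_data_STL10()
````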
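Checkpoints are saved to `model/` at the end of every epoch. Below is a hypothetical sampling snippet, not part of the repo; it assumes a checkpoint written by `main.py` and the `models.py` in this repo (pass `map_location='cpu'` to `torch.load` if the checkpoint was saved on GPU but loaded on CPU):

````python
import torch
import torchvision.utils as vutils
import models

netG = models.get_netG()  # builds the generator (and moves it to GPU if available)
netG.load_state_dict(torch.load('model/netG.pth'))
netG.eval()

noise = torch.randn(64, 100, 1, 1)  # nz = 100, as in models.py
if torch.cuda.is_available():
    noise = noise.cuda()

with torch.no_grad():
    fake = netG(noise)
vutils.save_image(fake.cpu(), 'samples.png', normalize=True)
````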
## TODO
- revise the custom loss (its value stops updating after long training runs)
- clamp the networks' gradients
- apply to more datasets
- add an argument parser

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils

transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

def load_data_STL10():
    train_dataset = dsets.STL10(root='./data/', split='train+unlabeled', download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    return train_loader

def load_data_CIFAR10():
    train_dataset = dsets.CIFAR10(root='./data/', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    return train_loader

class Log_loss(torch.nn.Module):
    def __init__(self):
        super(Log_loss, self).__init__()

    def forward(self, x, negation=True):
        # x has shape [batch_size]; with negation=True you minimize -sum(log x)
        log_val = torch.log(x)
        loss = torch.sum(log_val)
        if negation:
            loss = torch.neg(loss)
        return loss

class Itself_loss(torch.nn.Module):
    def __init__(self):
        super(Itself_loss, self).__init__()

    def forward(self, x, negation=True):
        # x has shape [batch_size]; with negation=True you minimize -sum(x)
        loss = torch.sum(x)
        if negation:
            loss = torch.neg(loss)
        return loss
--------------------------------------------------------------------------------