├── LICENSE
├── README.md
├── dragan.py
└── vanilla_gan.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 João Felipe Santos

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# dragan-pytorch
PyTorch implementation of DRAGAN (https://arxiv.org/abs/1705.07215)

Code based on the [original implementation](https://github.com/kodalinaveen3/DRAGAN) by the authors.

The following repositories were also used as references for implementing the gradient penalty in PyTorch:
- https://github.com/t-vi/pytorch-tvmisc
- https://github.com/caogang/wgan-gp
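
## Gradient penalty

DRAGAN regularizes the discriminator by penalizing its gradient norm at points
sampled around the real data: `lambda * E[(||grad_x D(x_hat)||_2 - 1)^2]`, where
`x_hat` interpolates between a real example and a noisy copy of it. The sketch
below mirrors the computation in `dragan.py`; the standalone `gradient_penalty`
helper is illustrative (the script inlines the same steps in its training loop):

```python
import torch
from torch.autograd import Variable, grad


def gradient_penalty(discriminator, X, lambda_=10.0):
    # X is a flattened (batch, features) real batch, as in dragan.py.
    # Interpolate between X and a perturbed copy of it, with noise
    # scaled by half the batch standard deviation.
    alpha = torch.rand(X.size(0), 1).expand(X.size())
    perturbed = X.data + 0.5 * X.data.std() * torch.rand(X.size())
    x_hat = Variable(alpha * X.data + (1 - alpha) * perturbed, requires_grad=True)
    # Differentiate D's output w.r.t. the interpolated inputs and
    # penalize deviations of the per-example gradient norm from 1.
    pred_hat = discriminator(x_hat)
    gradients = grad(outputs=pred_hat, inputs=x_hat,
                     grad_outputs=torch.ones(pred_hat.size()),
                     create_graph=True, retain_graph=True, only_inputs=True)[0]
    return lambda_ * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
```

Note that both scripts save image grids via `torchvision.utils.save_image` and
assume the output directories (`samples/` for `dragan.py`, `samples_vanilla/`
for `vanilla_gan.py`) already exist.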
--------------------------------------------------------------------------------
/dragan.py:
--------------------------------------------------------------------------------
# coding: utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch.autograd import Variable, grad
from torch.nn.init import xavier_normal
from torchvision import datasets, transforms
import torchvision.utils as vutils


def xavier_init(model):
    # Apply Xavier initialization to every weight matrix (biases are left as-is).
    for param in model.parameters():
        if len(param.size()) == 2:
            xavier_normal(param)


if __name__ == '__main__':
    batch_size = 128
    z_dim = 100    # latent (noise) dimensionality
    h_dim = 128    # hidden layer size
    y_dim = 784    # flattened 28x28 MNIST images
    max_epochs = 1000
    lambda_ = 10   # gradient penalty weight

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ])),
        batch_size=batch_size, shuffle=True, drop_last=True)
    # Note: the test set is loaded but never used during training.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor()
        ])),
        batch_size=batch_size, shuffle=False, drop_last=True)

    generator = torch.nn.Sequential(torch.nn.Linear(z_dim, h_dim),
                                    torch.nn.Sigmoid(),
                                    torch.nn.Linear(h_dim, y_dim),
                                    torch.nn.Sigmoid())

    discriminator = torch.nn.Sequential(torch.nn.Linear(y_dim, h_dim),
                                        torch.nn.Sigmoid(),
                                        torch.nn.Linear(h_dim, 1),
                                        torch.nn.Sigmoid())

    # Init weight matrices (xavier_normal)
    xavier_init(generator)
    xavier_init(discriminator)

    opt_g = torch.optim.Adam(generator.parameters())
    opt_d = torch.optim.Adam(discriminator.parameters())

    criterion = torch.nn.BCELoss()
    X = Variable(torch.FloatTensor(batch_size, y_dim))
    z = Variable(torch.FloatTensor(batch_size, z_dim))
    # (batch_size, 1) so the targets match the discriminator's output shape.
    labels = Variable(torch.FloatTensor(batch_size, 1))

    # Train
    for epoch in range(max_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten the (batch, 1, 28, 28) images into (batch, 784) rows.
            X.data.copy_(data.view(batch_size, y_dim))

            # Update discriminator
            # train with real
            discriminator.zero_grad()
            pred_real = discriminator(X)
            labels.data.fill_(1.0)
            loss_d_real = criterion(pred_real, labels)
            loss_d_real.backward()

            # train with fake; detach so no generator gradients are computed here
            z.data.normal_(0, 1)
            fake = generator(z).detach()
            pred_fake = discriminator(fake)
            labels.data.fill_(0.0)
            loss_d_fake = criterion(pred_fake, labels)
            loss_d_fake.backward()

            # Gradient penalty (DRAGAN): evaluate D on points interpolated
            # between the real batch and a noisy perturbation of it, and push
            # the per-example gradient norm towards 1.
            alpha = torch.rand(batch_size, 1).expand(X.size())
            x_hat = Variable(alpha * X.data + (1 - alpha) *
                             (X.data + 0.5 * X.data.std() * torch.rand(X.size())),
                             requires_grad=True)
            pred_hat = discriminator(x_hat)
            gradients = grad(outputs=pred_hat, inputs=x_hat,
                             grad_outputs=torch.ones(pred_hat.size()),
                             create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradient_penalty = lambda_ * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
            gradient_penalty.backward()

            loss_d = loss_d_real + loss_d_fake + gradient_penalty
            opt_d.step()

            # Update generator: try to make D label fresh samples as real
            generator.zero_grad()
            z.data.normal_(0, 1)
            gen = generator(z)
            pred_gen = discriminator(gen)
            labels.data.fill_(1.0)
            loss_g = criterion(pred_gen, labels)
            loss_g.backward()
            opt_g.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, max_epochs, batch_idx, len(train_loader),
                     loss_d.data[0], loss_g.data[0]))

            if batch_idx % 100 == 0:
                vutils.save_image(data, 'samples/real_samples.png')
                vutils.save_image(gen.data.view(batch_size, 1, 28, 28),
                                  'samples/fake_samples_epoch_%03d.png' % epoch)
--------------------------------------------------------------------------------
/vanilla_gan.py:
--------------------------------------------------------------------------------
# coding: utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch.autograd import Variable
from torch.nn.init import xavier_normal
from torchvision import datasets, transforms
import torchvision.utils as vutils


def xavier_init(model):
    # Apply Xavier initialization to every weight matrix (biases are left as-is).
    for param in model.parameters():
        if len(param.size()) == 2:
            xavier_normal(param)


if __name__ == '__main__':
    batch_size = 128
    z_dim = 100    # latent (noise) dimensionality
    h_dim = 128    # hidden layer size
    y_dim = 784    # flattened 28x28 MNIST images
    max_epochs = 1000

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ])),
        batch_size=batch_size, shuffle=True, drop_last=True)
    # Note: the test set is loaded but never used during training.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor()
        ])),
        batch_size=batch_size, shuffle=False, drop_last=True)

    generator = torch.nn.Sequential(torch.nn.Linear(z_dim, h_dim),
                                    torch.nn.Sigmoid(),
                                    torch.nn.Linear(h_dim, y_dim),
                                    torch.nn.Sigmoid())

    discriminator = torch.nn.Sequential(torch.nn.Linear(y_dim, h_dim),
                                        torch.nn.Sigmoid(),
                                        torch.nn.Linear(h_dim, 1),
                                        torch.nn.Sigmoid())

    # Init weight matrices (xavier_normal)
    xavier_init(generator)
    xavier_init(discriminator)

    opt_g = torch.optim.Adam(generator.parameters())
    opt_d = torch.optim.Adam(discriminator.parameters())

    criterion = torch.nn.BCELoss()
    X = Variable(torch.FloatTensor(batch_size, y_dim))
    z = Variable(torch.FloatTensor(batch_size, z_dim))
    # (batch_size, 1) so the targets match the discriminator's output shape.
    labels = Variable(torch.FloatTensor(batch_size, 1))

    # Train
    for epoch in range(max_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten the (batch, 1, 28, 28) images into (batch, 784) rows.
            X.data.copy_(data.view(batch_size, y_dim))

            # Update discriminator
            # train with real
            discriminator.zero_grad()
            pred_real = discriminator(X)
            labels.data.fill_(1.0)
            loss_d_real = criterion(pred_real, labels)
            loss_d_real.backward()

            # train with fake; detach so no generator gradients are computed here
            z.data.normal_(0, 1)
            fake = generator(z).detach()
            pred_fake = discriminator(fake)
            labels.data.fill_(0.0)
            loss_d_fake = criterion(pred_fake, labels)
            loss_d_fake.backward()

            loss_d = loss_d_real + loss_d_fake
            opt_d.step()

            # Update generator: try to make D label fresh samples as real
            generator.zero_grad()
            z.data.normal_(0, 1)
            gen = generator(z)
            pred_gen = discriminator(gen)
            labels.data.fill_(1.0)
            loss_g = criterion(pred_gen, labels)
            loss_g.backward()
            opt_g.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, max_epochs, batch_idx, len(train_loader),
                     loss_d.data[0], loss_g.data[0]))

            if batch_idx % 100 == 0:
                vutils.save_image(data, 'samples_vanilla/real_samples.png')
                vutils.save_image(gen.data.view(batch_size, 1, 28, 28),
                                  'samples_vanilla/fake_samples_epoch_%03d.png' % epoch)

--------------------------------------------------------------------------------