├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── fashion.gif
│   ├── fashion_hist.png
│   ├── fashion_interp.png
│   ├── mnist.gif
│   ├── mnist_hist.png
│   └── mnist_interp.png
├── interpolation.ipynb
├── main.py
├── models.py
├── saved_models
│   ├── fashion_generator.pt
│   └── mnist_generator.pt
└── wgangp.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | 
.ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## PyTorch WGAN-GP 2 | 3 | This is a PyTorch implementation of [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028). Most of the code was inspired by [this repository](https://github.com/EmilienDupont/wgan-gp) by [EmilienDupont](https://github.com/EmilienDupont). 
4 | 5 | ## Training 6 | 7 | To train on the MNIST dataset, run 8 | ``` 9 | python main.py --dataset mnist --epochs 200 10 | ``` 11 | For the FashionMNIST dataset, run 12 | ``` 13 | python main.py --dataset fashion --epochs 200 14 | ``` 15 | You can also set up a generator and discriminator pair and use the `WGANGP` class: 16 | ```python 17 | wgan = WGANGP(generator, discriminator, 18 | g_optimizer, d_optimizer, 19 | latent_shape, dataset_name) 20 | wgan.train(data_loader, n_epochs) 21 | ``` 22 | 23 | The argument `latent_shape` is the shape of the input that the generator's forward function accepts, excluding the batch dimension. 24 | 25 | The training process is monitored by [tensorboardX](https://github.com/lanpa/tensorboardX). 26 | 27 | ## Results 28 | 29 | Here is the training history for both datasets: 30 | 31 | ![MNIST losses](images/mnist_hist.png) 32 | ![fashion losses](images/fashion_hist.png) 33 | 34 | Two GIFs of the training process: 35 | 36 | ![MNIST training gif](images/mnist.gif) 37 | ![fashion training gif](images/fashion.gif) 38 | 39 | ## Interpolation in latent space 40 | 41 | We can generate samples that change smoothly from one class to another by interpolating points in the latent space (done in [this notebook](interpolation.ipynb)): 42 | 43 | ![MNIST interpolation](images/mnist_interp.png) 44 | ![fashion interpolation](images/fashion_interp.png) 45 | 46 | The weights of the trained generators are in the `saved_models` folder. 
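The interpolation step can be sketched in a few lines; this is a minimal illustration (assuming the repository layout, e.g. `Generator(100)` from `models.py` and the `saved_models` weights — the notebook may differ in detail):

```python
import torch

# Latent points have shape (100, 1, 1), matching Generator(100) in models.py.
latent_shape = (100, 1, 1)
z_start = torch.randn(1, *latent_shape)
z_end = torch.randn(1, *latent_shape)

# Linear interpolation: 10 evenly spaced points from z_start to z_end.
steps = torch.linspace(0, 1, 10).view(-1, 1, 1, 1)
z_path = (1 - steps) * z_start + steps * z_end   # shape: (10, 100, 1, 1)

# Decoding z_path with a trained generator yields grids like the ones above:
#   generator = Generator(100)
#   generator.load_state_dict(torch.load('saved_models/mnist_generator.pt'))
#   images = generator(z_path)
```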
47 | -------------------------------------------------------------------------------- /images/fashion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/images/fashion.gif -------------------------------------------------------------------------------- /images/fashion_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/images/fashion_hist.png -------------------------------------------------------------------------------- /images/fashion_interp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/images/fashion_interp.png -------------------------------------------------------------------------------- /images/mnist.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/images/mnist.gif -------------------------------------------------------------------------------- /images/mnist_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/images/mnist_hist.png -------------------------------------------------------------------------------- /images/mnist_interp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/images/mnist_interp.png -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import pandas as pd 4 | from torchvision.datasets import MNIST, FashionMNIST 5 | from torchvision import transforms 6 | from torch.utils.data import DataLoader, ConcatDataset 7 | from models import Generator, Discriminator 8 | from wgangp import WGANGP 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion']) 16 | parser.add_argument('--epochs', type=int, default=200) 17 | parser.add_argument('--batch_size', type=int, default=64) 18 | parser.add_argument('--use_cuda', type=str, default='True', choices=['True', 'False']) 19 | 20 | args = parser.parse_args() 21 | 22 | transform = transforms.Compose([ 23 | transforms.Resize(32), 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean=[0.5], std=[0.5]) 26 | ]) 27 | 28 | if args.dataset == 'mnist': 29 | train_dataset = MNIST(root='data/mnist', train=True, download=True, transform=transform) 30 | test_dataset = MNIST(root='data/mnist', train=False, download=True, transform=transform) 31 | 32 | if args.dataset == 'fashion': 33 | train_dataset = FashionMNIST(root='data/fashion', train=True, download=True, transform=transform) 34 | test_dataset = FashionMNIST(root='data/fashion', train=False, download=True, transform=transform) 35 | 36 | full_dataset = ConcatDataset([train_dataset, test_dataset]) 37 | data_loader = DataLoader(full_dataset, batch_size=args.batch_size, shuffle=True) 38 | 39 | generator = Generator(100) 40 | discriminator = Discriminator() 41 | 42 | g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9)) 43 | d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9)) 44 | 45 | wgan = WGANGP(generator, discriminator, g_optimizer, d_optimizer, [100, 1, 1], args.dataset, use_cuda=(args.use_cuda == 'True' and torch.cuda.is_available())) 46 | wgan.train(data_loader, args.epochs) 47 | 
48 | pd.DataFrame(wgan.hist).to_csv(args.dataset + '_hist.csv', index=False) 49 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Generator(nn.Module): 5 | def __init__(self, latent_dim, d=32): 6 | super().__init__() 7 | self.net = nn.Sequential( 8 | nn.ConvTranspose2d(latent_dim, d * 8, 4, 1, 0), 9 | nn.BatchNorm2d(d * 8), 10 | nn.ReLU(True), 11 | 12 | nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1), 13 | nn.BatchNorm2d(d * 4), 14 | nn.ReLU(True), 15 | 16 | nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1), 17 | nn.BatchNorm2d(d * 2), 18 | nn.ReLU(True), 19 | 20 | nn.ConvTranspose2d(d * 2, 1, 4, 2, 1), 21 | nn.Tanh() 22 | ) 23 | 24 | def forward(self, x): 25 | return self.net(x) 26 | 27 | 28 | class Discriminator(nn.Module): 29 | def __init__(self, d=32): 30 | super().__init__() 31 | self.net = nn.Sequential( 32 | nn.Conv2d(1, d, 4, 2, 1), 33 | nn.InstanceNorm2d(d), 34 | nn.LeakyReLU(0.2), 35 | 36 | nn.Conv2d(d, d * 2, 4, 2, 1), 37 | nn.InstanceNorm2d(d * 2), 38 | nn.LeakyReLU(0.2), 39 | 40 | nn.Conv2d(d * 2, d * 4, 4, 2, 1), 41 | nn.InstanceNorm2d(d * 4), 42 | nn.LeakyReLU(0.2), 43 | 44 | nn.Conv2d(d * 4, 1, 4, 1, 0), 45 | ) 46 | 47 | def forward(self, x): 48 | outputs = self.net(x) 49 | return outputs.squeeze() 50 | -------------------------------------------------------------------------------- /saved_models/fashion_generator.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/saved_models/fashion_generator.pt -------------------------------------------------------------------------------- /saved_models/mnist_generator.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/arturml/pytorch-wgan-gp/4f95d23039f512eb3c31d388d5b66ab7169f360c/saved_models/mnist_generator.pt -------------------------------------------------------------------------------- /wgangp.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torchvision.utils import make_grid 6 | from torch.autograd import Variable 7 | from torch import autograd 8 | from tensorboardX import SummaryWriter 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | class WGANGP(): 14 | def __init__(self, generator, discriminator, g_optimizer, d_optimizer, 15 | latent_shape, dataset_name, n_critic=5, gamma=10, 16 | save_every=20, use_cuda=True, logdir=None): 17 | 18 | self.G = generator 19 | self.D = discriminator 20 | self.G_opt = g_optimizer 21 | self.D_opt = d_optimizer 22 | self.latent_shape = latent_shape 23 | self.dataset_name = dataset_name 24 | self.n_critic = n_critic 25 | self.gamma = gamma 26 | self.save_every = save_every 27 | self.use_cuda = use_cuda 28 | self.writer = SummaryWriter(logdir) 29 | self.steps = 0 30 | self._fixed_z = torch.randn(64, *latent_shape) 31 | self.hist = [] 32 | self.images = [] 33 | 34 | if self.use_cuda: 35 | self._fixed_z = self._fixed_z.cuda() 36 | self.G.cuda() 37 | self.D.cuda() 38 | 39 | def train(self, data_loader, n_epochs): 40 | self._save_gif() 41 | for epoch in range(1, n_epochs + 1): 42 | print('Starting epoch {}...'.format(epoch)) 43 | self._train_epoch(data_loader) 44 | self._save_gif() 45 | if epoch % self.save_every == 0 or epoch == n_epochs: 46 | torch.save(self.G.state_dict(), self.dataset_name + '_gen_{}.pt'.format(epoch)) 47 | torch.save(self.D.state_dict(), self.dataset_name + '_disc_{}.pt'.format(epoch)) 48 | 49 | def _train_epoch(self, data_loader): 50 | for i, (data, _) in enumerate(data_loader): 51 | self.steps += 1 52 | data = Variable(data) 53 | if self.use_cuda: 54 | 
data = data.cuda() 55 | 56 | d_loss, grad_penalty = self._discriminator_train_step(data) 57 | self.writer.add_scalars('losses', {'d_loss': d_loss, 'grad_penalty': grad_penalty}, self.steps) 58 | self.hist.append({'d_loss': d_loss, 'grad_penalty': grad_penalty}) 59 | 60 | if i % 200 == 0: 61 | img_grid = make_grid(self.G(self._fixed_z).cpu().data, normalize=True) 62 | self.writer.add_image('images', img_grid, self.steps) 63 | 64 | if self.steps % self.n_critic == 0: 65 | g_loss = self._generator_train_step(data.size(0)) 66 | self.writer.add_scalars('losses', {'g_loss': g_loss}, self.steps) 67 | self.hist[-1]['g_loss'] = g_loss 68 | 69 | print(' g_loss: {:.3f} d_loss: {:.3f} grad_penalty: {:.3f}'.format(g_loss, d_loss, grad_penalty)) 70 | 71 | def _discriminator_train_step(self, data): 72 | batch_size = data.size(0) 73 | generated_data = self._sample(batch_size) 74 | grad_penalty = self._gradient_penalty(data, generated_data) 75 | d_loss = self.D(generated_data).mean() - self.D(data).mean() + grad_penalty 76 | self.D_opt.zero_grad() 77 | d_loss.backward() 78 | self.D_opt.step() 79 | return d_loss.item(), grad_penalty.item() 80 | 81 | def _generator_train_step(self, batch_size): 82 | self.G_opt.zero_grad() 83 | generated_data = self._sample(batch_size) 84 | g_loss = -self.D(generated_data).mean() 85 | g_loss.backward() 86 | self.G_opt.step() 87 | return g_loss.item() 88 | 89 | def _gradient_penalty(self, data, generated_data): 90 | batch_size = data.size(0) 91 | epsilon = torch.rand(batch_size, 1, 1, 1) 92 | epsilon = epsilon.expand_as(data) 93 | 94 | 95 | if self.use_cuda: 96 | epsilon = epsilon.cuda() 97 | 98 | interpolation = epsilon * data.data + (1 - epsilon) * generated_data.data 99 | interpolation = Variable(interpolation, requires_grad=True) 100 | 101 | if self.use_cuda: 102 | interpolation = interpolation.cuda() 103 | 104 | interpolation_logits = self.D(interpolation) 105 | grad_outputs = torch.ones(interpolation_logits.size()) 106 | 107 | if 
self.use_cuda: 108 | grad_outputs = grad_outputs.cuda() 109 | 110 | gradients = autograd.grad(outputs=interpolation_logits, 111 | inputs=interpolation, 112 | grad_outputs=grad_outputs, 113 | create_graph=True, 114 | retain_graph=True)[0] 115 | 116 | gradients = gradients.view(batch_size, -1) 117 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) 118 | return self.gamma * ((gradients_norm - 1) ** 2).mean() 119 | 120 | def _sample(self, n_samples): 121 | z = Variable(torch.randn(n_samples, *self.latent_shape)) 122 | if self.use_cuda: 123 | z = z.cuda() 124 | return self.G(z) 125 | 126 | def _save_gif(self): 127 | grid = make_grid(self.G(self._fixed_z).cpu().data, normalize=True) 128 | grid = np.transpose(grid.numpy(), (1, 2, 0)) 129 | self.images.append(grid) 130 | imageio.mimsave('{}.gif'.format(self.dataset_name), self.images) 131 | --------------------------------------------------------------------------------
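The gradient-penalty term in `_gradient_penalty` can be exercised in isolation. Below is a minimal sketch with a toy linear critic (hypothetical, for illustration only — the repository's `Discriminator` is convolutional), following the same interpolate-then-differentiate steps as the class:

```python
import torch
from torch import autograd

# Toy critic standing in for the convolutional Discriminator.
critic = torch.nn.Linear(8, 1)

real = torch.randn(4, 8)
fake = torch.randn(4, 8)

# Random per-sample interpolation between real and fake batches.
epsilon = torch.rand(4, 1).expand_as(real)
interp = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)

# Gradient of the critic's output with respect to the interpolated input.
logits = critic(interp)
gradients = autograd.grad(outputs=logits, inputs=interp,
                          grad_outputs=torch.ones_like(logits),
                          create_graph=True, retain_graph=True)[0]

# Penalize deviation of the per-sample gradient norm from 1 (gamma = 10,
# as in the paper and in the WGANGP default).
gradients_norm = gradients.view(4, -1).norm(2, dim=1)
grad_penalty = 10 * ((gradients_norm - 1) ** 2).mean()
```

For a linear critic the input gradient is just the weight matrix, so the penalty pushes the critic's weight norm toward 1 — the same soft 1-Lipschitz constraint the class enforces on the Discriminator.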