├── assets └── PyTorch │ ├── loss_plot.png │ ├── generated_images.gif │ └── generated_images │ ├── epoch_0.png │ ├── epoch_120.png │ ├── epoch_160.png │ ├── epoch_200.png │ ├── epoch_40.png │ └── epoch_80.png ├── LICENSE ├── src └── PyTorch │ ├── input_data.py │ └── gan-mnist-pytorch.py ├── .gitignore └── README.md /assets/PyTorch/loss_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/loss_plot.png -------------------------------------------------------------------------------- /assets/PyTorch/generated_images.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images.gif -------------------------------------------------------------------------------- /assets/PyTorch/generated_images/epoch_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_0.png -------------------------------------------------------------------------------- /assets/PyTorch/generated_images/epoch_120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_120.png -------------------------------------------------------------------------------- /assets/PyTorch/generated_images/epoch_160.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_160.png -------------------------------------------------------------------------------- /assets/PyTorch/generated_images/epoch_200.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_200.png -------------------------------------------------------------------------------- /assets/PyTorch/generated_images/epoch_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_40.png -------------------------------------------------------------------------------- /assets/PyTorch/generated_images/epoch_80.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_80.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Vamsi Krishna 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/PyTorch/input_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Functions for downloading and reading MNIST data.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | import gzip 23 | import os 24 | import tempfile 25 | 26 | import numpy 27 | from six.moves import urllib 28 | from six.moves import xrange # pylint: disable=redefined-builtin 29 | import tensorflow as tf 30 | from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets 31 | # pylint: enable=unused-import 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple GAN using PyTorch 2 | > This project is a basic Generative Adversarial Network (GAN) implemented in PyTorch on the MNIST database 3 | 4 | This is one of my first steps towards GANs in general. It closely follows the original GAN formulation published in [arXiv:1406.2661 [stat.ML]](https://arxiv.org/pdf/1406.2661.pdf) by Goodfellow _et al._ 5 | 6 | ## Getting the code to work 7 | 8 | Follow the instructions below to get the project running on your local machine. 9 | 10 | 1. Clone the repository and make sure you have Python 3.6 or newer (the script uses f-strings). 11 | 2. Go to `src/PyTorch/` and run `python gan-mnist-pytorch.py` 12 | 3. 
All outputs and plots are written to the `src/PyTorch/output` folder (or the directory given by `--output_dir`), which the script creates. 13 | 4. Run `python gan-mnist-pytorch.py --help` to list the parameters that can be tweaked before a run. 14 | 15 | ### Prerequisites 16 | 17 | * PyTorch 0.4.0 or above, plus torchvision, NumPy, Matplotlib and imageio (all imported by the script) 18 | * CUDA 9.1 (or the version matching your PyTorch build) to use a compatible GPU, if present, for faster training 19 | 20 | ### Results 21 | 22 |
23 | Generated Images 24 |
25 | 26 | Images generated by Generator at various epochs - 27 |
28 | 29 | ![Generated on Epoch 0](assets/PyTorch/generated_images/epoch_0.png) | ![Generated on Epoch 40](assets/PyTorch/generated_images/epoch_40.png) | ![Generated on Epoch 80](assets/PyTorch/generated_images/epoch_80.png) 30 | :----------------------------------------------------------------------:|:-------------------------------------------------------------------------:|:--------------------------------------------------------: 31 | Generated on Epoch 0 | Generated on Epoch 40 | Generated on Epoch 80 32 | ![Generated on Epoch 120](assets/PyTorch/generated_images/epoch_120.png) | ![Generated on Epoch 160](assets/PyTorch/generated_images/epoch_160.png) | ![Generated on Epoch 200](assets/PyTorch/generated_images/epoch_200.png) 33 | Generated on Epoch 120 | Generated on Epoch 160 | Generated on Epoch 200 34 | 35 |
36 | 37 | ## Author 38 | 39 | * **Vamsi Krishna Reddy Satti** - [vamsi3](https://github.com/vamsi3) 40 | 41 | ## License 42 | 43 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. -------------------------------------------------------------------------------- /src/PyTorch/gan-mnist-pytorch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | # Importing torch modules 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | # For MNIST dataset and visualization 13 | from torch.utils.data import DataLoader 14 | from torchvision import datasets 15 | import torchvision.transforms as transforms 16 | from torchvision.utils import save_image 17 | 18 | # Getting the command-line arguments 19 | import argparse 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--epochs', type=int, default=200, help="number of epochs to train") 22 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 23 | parser.add_argument('--learning_rate', type=float, default=0.0002, help='adam: learning rate') 24 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay rate of first order momentum of gradient') 25 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay rate of second order momentum of gradient') 26 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 27 | parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension') 28 | parser.add_argument('--channels', type=int, default=1, help='number of image channels') 29 | parser.add_argument('--output_dir', type=str, default='output', help='name of output directory') 30 | args = parser.parse_args() 31 | 32 | img_shape = (args.channels, 
args.img_size, args.img_size) 33 | 34 | # Check CUDA's presence 35 | cuda_is_present = True if torch.cuda.is_available() else False 36 | 37 | class Generator(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | 41 | def layer_block(input_size, output_size, normalize=True): 42 | layers = [nn.Linear(input_size, output_size)] 43 | if normalize: 44 | layers.append(nn.BatchNorm1d(output_size, 0.8)) 45 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 46 | return layers 47 | 48 | self.model = nn.Sequential( 49 | *layer_block(args.latent_dim, 128, normalize=False), 50 | *layer_block(128, 256), 51 | *layer_block(256, 512), 52 | *layer_block(512, 1024), 53 | nn.Linear(1024, int(np.prod(img_shape))), 54 | nn.Tanh() 55 | ) 56 | 57 | def forward(self, z): 58 | img = self.model(z) 59 | img = img.view(img.size(0), *img_shape) 60 | return img 61 | 62 | class Discriminator(nn.Module): 63 | def __init__(self): 64 | super().__init__() 65 | 66 | self.model = nn.Sequential( 67 | nn.Linear(int(np.prod(img_shape)), 512), 68 | nn.LeakyReLU(0.2, inplace=True), 69 | nn.Linear(512, 256), 70 | nn.LeakyReLU(0.2, inplace=True), 71 | nn.Linear(256, 1), 72 | nn.Sigmoid() 73 | ) 74 | 75 | def forward(self, img): 76 | img_flat = img.view(img.size(0), -1) 77 | verdict = self.model(img_flat) 78 | return verdict 79 | 80 | # Utilize CUDA if available 81 | generator = Generator() 82 | discriminator = Discriminator() 83 | adversarial_loss = torch.nn.BCELoss() 84 | 85 | if cuda_is_present: 86 | generator.cuda() 87 | discriminator.cuda() 88 | adversarial_loss.cuda() 89 | 90 | # Loading MNIST dataset 91 | os.makedirs('data/mnist', exist_ok=True) 92 | data_loader = torch.utils.data.DataLoader( 93 | datasets.MNIST('/data/mnist', train=True, download=True, 94 | transform=transforms.Compose([ 95 | transforms.ToTensor(), 96 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 97 | ])), 98 | batch_size=args.batch_size, shuffle=True) 99 | 100 | # Training the GAN 101 | 
os.makedirs(f'{args.output_dir}/images', exist_ok=True) 102 | Tensor = torch.cuda.FloatTensor if cuda_is_present else torch.FloatTensor 103 | 104 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=args.learning_rate, betas=(args.b1, args.b2)) 105 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=args.learning_rate, betas=(args.b1, args.b2)) 106 | 107 | losses = [] 108 | images_for_gif = [] 109 | for epoch in range(1, args.epochs+1): 110 | for i, (images, _) in enumerate(data_loader): 111 | 112 | real_images = Variable(images.type(Tensor)) 113 | real_output = Variable(Tensor(images.size(0), 1).fill_(1.0), requires_grad=False) 114 | fake_output = Variable(Tensor(images.size(0), 1).fill_(0.0), requires_grad=False) 115 | 116 | # Training Generator 117 | optimizer_generator.zero_grad() 118 | z = Variable(Tensor(np.random.normal(0, 1, (images.shape[0], args.latent_dim)))) 119 | generated_images = generator(z) 120 | generator_loss = adversarial_loss(discriminator(generated_images), real_output) 121 | generator_loss.backward() 122 | optimizer_generator.step() 123 | 124 | # Training Discriminator 125 | optimizer_discriminator.zero_grad() 126 | discriminator_loss_real = adversarial_loss(discriminator(real_images), real_output) 127 | discriminator_loss_fake = adversarial_loss(discriminator(generated_images.detach()), fake_output) 128 | discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2 129 | discriminator_loss.backward() 130 | optimizer_discriminator.step() 131 | 132 | print(f"[Epoch {epoch:=4d}/{args.epochs}] [Batch {i:=4d}/{len(data_loader)}] ---> " 133 | f"[D Loss: {discriminator_loss.item():.6f}] [G Loss: {generator_loss.item():.6f}]") 134 | 135 | losses.append((generator_loss.item(), discriminator_loss.item())) 136 | image_filename = f'{args.output_dir}/images/{epoch}.png' 137 | save_image(generated_images.data[:25], image_filename, nrow=5, normalize=True) 138 | 
images_for_gif.append(imageio.imread(image_filename)) 139 | 140 | # Visualizing the losses at every epoch 141 | losses = np.array(losses) 142 | plt.plot(losses.T[0], label='Generator') 143 | plt.plot(losses.T[1], label='Discriminator') 144 | plt.title("Training Losses") 145 | plt.xlabel("Epoch") 146 | plt.ylabel("Loss") 147 | plt.legend() 148 | plt.savefig(f'{args.output_dir}/loss_plot.png') 149 | 150 | # Creating a gif of generated images at every epoch 151 | imageio.mimwrite(f'{args.output_dir}/generated_images.gif', images_for_gif, fps=len(images)/5) 152 | --------------------------------------------------------------------------------