├── .DS_Store
├── README.md
├── data_loader.py
├── default-cifar10
│   ├── .DS_Store
│   ├── genereated_images
│   │   ├── .DS_Store
│   │   ├── img_100.png
│   │   ├── img_120.png
│   │   ├── img_140.png
│   │   ├── img_160.png
│   │   ├── img_180.png
│   │   ├── img_20.png
│   │   ├── img_200.png
│   │   ├── img_40.png
│   │   ├── img_60.png
│   │   └── img_80.png
│   ├── loss_graphs
│   │   └── events.out.tfevents.1567243598.deepak-HP.31654.0
│   └── training_checkpoints
│       └── .DS_Store
├── default-mnist
│   ├── .DS_Store
│   ├── genereated_images
│   │   ├── .DS_Store
│   │   ├── mnist_img_0.png
│   │   ├── mnist_img_20.png
│   │   ├── mnist_img_40.png
│   │   ├── mnist_img_60.png
│   │   └── mnist_img_80.png
│   ├── loss_graphs
│   │   ├── .DS_Store
│   │   └── events.out.tfevents.1567240017.deepak-HP.31233.0
│   └── training_checkpoints
│       └── .DS_Store
├── model.py
├── requirements.txt
├── resources
│   ├── .DS_Store
│   ├── cifar_disc_loss.png
│   ├── cifar_gen_loss.png
│   ├── cifar_training.gif
│   ├── generated_image_cifar.png
│   ├── generated_image_mnist.png
│   ├── mnist_disc_loss.png
│   ├── mnist_gen_loss.png
│   └── mnist_training.gif
└── train.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Fully Connected GAN (also known as Vanilla GAN) in PyTorch
This repository contains code for a fully connected GAN (FCGAN), trained and tested on the MNIST and CIFAR-10 datasets. It is built on the PyTorch framework.

## Generative Adversarial Networks
GANs are generally made up of two models: The Artist (Generator) and The Critic (Discriminator). The generator creates an image from random noise, and the discriminator compares generated images against images from the given dataset. The two models are trained against each other in a minimax game: the generator tries to fool the discriminator by producing realistic-looking images, while the discriminator gets better at telling real images from fake ones. This two-player game improves both models until the generator produces realistic images or the system reaches a Nash equilibrium.

## Contents
1. [Setup Instructions and Dependencies](#1-setup-instructions-and-dependencies)
2. [Training Model from Scratch](#2-training-model-from-scratch)
3. [Generating Images from Trained Models](#3-generating-images-from-trained-models)
4. [Model Architecture](#4-model-architecture)
5. [Repository Overview](#5-repository-overview)
6. [Results Obtained](#6-results-obtained)
    1. [Generated Images](#i-generated-images)
    2. [Parameters Used](#ii-parameters-used)
    3. [Loss Curves](#iii-loss-curves)
7. [Observations](#7-observations)
8. [Credits](#8-credits)

## 1. Setup Instructions and Dependencies
You may set up the repository on your local machine by either downloading it or running the following line in a terminal.

```Batchfile
git clone https://github.com/h3lio5/gan-pytorch.git
```

The trained models are large, so their Google Drive links are provided in the `model.txt` file.

The data required for training is downloaded automatically when running `train.py`.

All dependencies required by this repo can be installed by creating a virtual environment with Python 3.7 and running

```Batchfile
pip install -r requirements.txt
```

Make sure to have CUDA 10.0.130 and cuDNN 7.6.0 installed in the virtual environment. For a conda environment, this can be done with the following commands:

```Batchfile
conda install cudatoolkit=10.0
conda install cudnn=7.6.0
```

## 2. Training Model from Scratch
To train your own model from scratch, run

```Batchfile
python train.py --dataset mnist --exp_name default-mnist
```

+ All parameters for your experiment have defaults (see the `argparse` flags in `train.py`), but you are free to set them yourself.
+ The training script will create a folder named after the `--exp_name` argument.
+ This folder will contain all data related to your experiment, such as tensorboard logs, images generated during training and training checkpoints.

## 3. Generating Images from Trained Models
To generate images from trained models, run

```Batchfile
python generate.py --dataset mnist/cifar10 --load_path path/to/checkpoint --grid_size n --save_path directory/where/images/are/saved
```

The arguments are as follows:

+ `--dataset` requires either `mnist` or `cifar10`, according to which dataset the model was trained on.
+ `--load_path` requires the path to the training checkpoint to load. Point this towards the `*.index` file without the extension. For example `--load_path training_checkpoints/ckpt-1`.
+ `--grid_size` requires an integer `n` and will generate n*n images in a grid.
+ `--save_path` requires the path to the directory where the generated images will be saved. If the directory doesn't exist, the script will create it.

To generate images from pre-trained models, download the checkpoint files from the Google Drive link given in the `model.txt` file.

## 4. Model Architecture
### Generator Model
+ `MNIST`: The generator model is a 5-layer MLP with LeakyReLU activations, followed by a Tanh non-linearity in the final layer.
+ `CIFAR10`: The generator model is a 6-layer MLP with LeakyReLU activations, followed by a Tanh non-linearity in the final layer.
+ Input is a 100-dimensional noise vector. It is passed through the network to produce either a 28x28x1 (MNIST) or 32x32x3 (CIFAR-10) image.

### Discriminator Model
+ `MNIST`: The discriminator model is a 3-layer MLP with LeakyReLU activations, followed by a Sigmoid non-linearity in the final layer.
+ `CIFAR10`: The discriminator model is a 4-layer MLP with LeakyReLU activations, followed by a Sigmoid non-linearity in the final layer.
+ Output is a single number that indicates whether the image is real or fake/generated.

## 5. Repository Overview
This repository contains the following files and folders

1. **experiments**: This folder contains data for different runs.

2. **resources**: Contains media for `README.md`.

3. `data_loader.py`: Contains helper functions that load and preprocess data.

4. `generate.py`: Used to generate and save images from trained models.

5. `model.py`: Contains the generator and discriminator model definitions.

6. `model.txt`: Contains Google Drive links to trained models.

7. `requirements.txt`: Lists dependencies for easy setup in virtual environments.

8. `train.py`: Contains code to train models from scratch.

## 6. Results Obtained
### i. Generated Images
Samples generated after training the model for 100 epochs on MNIST.

![mnist_generated](resources/generated_image_mnist.png)

Samples generated after training the model for 200 epochs on CIFAR-10.

![cifar_generated](resources/generated_image_cifar.png)

### ii. Parameters Used
+ Optimizer used is Adam
+ Learning rate 0.0002, beta-1 0.5
+ Trained for 100 epochs (MNIST) and 200 epochs (CIFAR-10)
+ Batch size is 128 for both MNIST and CIFAR-10
+ Real images are labeled 1 and fake images are labeled 0 (the convention set in `train.py`)

### iii. Loss Curves
#### MNIST

#### CIFAR-10

## 7. Observations
### MNIST
The model took around 12 minutes to train for 100 epochs on the GPU. The generated images are not that sharp but somewhat resemble the real data. The model is also prone to mode collapse.

Training for longer (150+ epochs) does not seem to improve the model's performance and sometimes even degrades the quality of the images produced.

### CIFAR-10
Training on the CIFAR-10 dataset was challenging. The dataset is varied and the network has more parameters to train. The model was trained for 200 epochs and took about 30 minutes to train.

However, the main problem I faced was looking at 32x32 images and judging whether they were 'good enough'. The images are too low-resolution to properly make out the subject, but they are easily passable since they look quite similar to the real data.

Some images have noise, but most images don't have many artifacts in them. The noise is partly due to the network training on all 10 labels of the CIFAR-10 dataset. Better results could be obtained by training the network on one particular label at a time, but this takes away the robustness of the model.

## 8. Credits
To make this repository I referenced multiple sources:
+ [Generative Adversarial Networks — Goodfellow et al. (2014)](https://arxiv.org/abs/1406.2661)
+ [MIT intro to deep learning (2019) - Lecture 4](https://www.youtube.com/watch?v=yFBFl1cLYx8)
+ [Ian Goodfellow: Generative Adversarial Networks (NIPS 2016 tutorial)](https://www.youtube.com/watch?v=HGYYEUSm-0Q)

--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
from torchvision.datasets import mnist, cifar
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


class DataLoad():

    def __init__(self):
        pass

    def load_data_mnist(self, batch_size=128):
        '''
        Returns a shuffled DataLoader over the MNIST training set.
        The 60000 images are divided into batches of (batch_size) each.
        '''
        # ToTensor scales pixels to [0, 1]; Normalize maps them to [-1, 1],
        # matching the Tanh output range of the generator.
        mnist_data = mnist.MNIST(root='./data/mnist', train=True, download=True, transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))
        mnist_loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
        return mnist_loader

    def load_data_cifar10(self, batch_size=128):
        '''
        Returns a shuffled DataLoader over the CIFAR-10 training set.
        The 50000 images are divided into batches of (batch_size) each.
        '''
        # Same [-1, 1] scaling as MNIST, written out for the three RGB channels.
        cifar_data = cifar.CIFAR10(root='./data/cifar10', train=True, download=True, transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]))
        cifar_loader = DataLoader(cifar_data, batch_size=batch_size, shuffle=True)
        return cifar_loader

--------------------------------------------------------------------------------
/default-cifar10/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/.DS_Store
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/.DS_Store
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_100.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_120.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_140.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_140.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_160.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_160.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_180.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_180.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_20.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_200.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_40.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_60.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_60.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_80.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_80.png
--------------------------------------------------------------------------------
/default-cifar10/loss_graphs/events.out.tfevents.1567243598.deepak-HP.31654.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/loss_graphs/events.out.tfevents.1567243598.deepak-HP.31654.0
--------------------------------------------------------------------------------
/default-cifar10/training_checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/training_checkpoints/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/genereated_images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_0.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_20.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_40.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_60.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_60.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_80.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_80.png
--------------------------------------------------------------------------------
/default-mnist/loss_graphs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/loss_graphs/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/loss_graphs/events.out.tfevents.1567240017.deepak-HP.31233.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/loss_graphs/events.out.tfevents.1567240017.deepak-HP.31233.0
--------------------------------------------------------------------------------
/default-mnist/training_checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/training_checkpoints/.DS_Store
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Generator(nn.Module):

    def __init__(self, model_type, z_dim=100):
        super(Generator, self).__init__()

        self.model_type = model_type
        # (channels, height, width) of the images produced for each dataset
        self.image_shape = {'mnist': (1, 28, 28),
                            'cifar10': (3, 32, 32)
                            }
        # One MLP per dataset; forward() selects the one matching model_type
        self.models = nn.ModuleDict({
            'mnist': nn.Sequential(
                nn.Linear(z_dim, 128, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(128, 256, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 512, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 1024, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(1024, int(np.prod(self.image_shape[model_type]))),
                nn.Tanh()
            ),
            'cifar10': nn.Sequential(
                nn.Linear(z_dim, 128, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(128, 256, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 512, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 1024, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(1024, 2048, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(2048, int(np.prod(self.image_shape[model_type]))),
                nn.Tanh()
            )
        })

    def forward(self, z):
        img = self.models[self.model_type](z)
        # Reshape the flat output vector into (batch, channels, height, width)
        img = img.view(img.size(0), *self.image_shape[self.model_type])
        return img


class Discriminator(nn.Module):

    def __init__(self, model_type):
        super(Discriminator, self).__init__()

        self.model_type = model_type
        self.image_shape = {'mnist': (1, 28, 28),
                            'cifar10': (3, 32, 32)
                            }
        self.models = nn.ModuleDict({
            'mnist': nn.Sequential(
                nn.Linear(int(np.prod(self.image_shape[model_type])), 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 1),
                nn.Sigmoid(),
            ),
            'cifar10': nn.Sequential(
                nn.Linear(int(np.prod(self.image_shape[model_type])), 1024),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(1024, 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 1),
                nn.Sigmoid(),
            )
        })

    def forward(self, img):
        # Flatten each image into a vector before the MLP
        img_flat = img.view(img.size(0), -1)
        output = self.models[self.model_type](img_flat)
        return output

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorboard==1.14.0
torch==1.2.0+cu92
torchvision==0.4.0+cu92
numpy==1.17.0

--------------------------------------------------------------------------------
/resources/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/.DS_Store
--------------------------------------------------------------------------------
/resources/cifar_disc_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/cifar_disc_loss.png
--------------------------------------------------------------------------------
/resources/cifar_gen_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/cifar_gen_loss.png
--------------------------------------------------------------------------------
/resources/cifar_training.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/cifar_training.gif
--------------------------------------------------------------------------------
/resources/generated_image_cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/generated_image_cifar.png
--------------------------------------------------------------------------------
/resources/generated_image_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/generated_image_mnist.png
--------------------------------------------------------------------------------
/resources/mnist_disc_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/mnist_disc_loss.png
--------------------------------------------------------------------------------
/resources/mnist_gen_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/mnist_gen_loss.png
--------------------------------------------------------------------------------
/resources/mnist_training.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/mnist_training.gif
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from data_loader import DataLoad
from model import *
import torch.nn as nn
from torch.utils import tensorboard
from torchvision.utils import save_image, make_grid
import os
import argparse


cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
# The generator is only updated during every k-th epoch (see the training loop)
k = 4
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=128, type=int, help='Enter the batch size')
parser.add_argument('--total_epochs', default=100, type=int, help='Enter the total number of epochs')
parser.add_argument('--dataset', default='mnist', choices=['mnist', 'cifar10'], help='Enter the dataset you want the model to train on')
parser.add_argument('--model_save_frequency', default=20, type=int, help='How often do you want to save the model state')
parser.add_argument('--image_sample_frequency', default=20, type=int, help='How often do you want to sample images')
# Optimizer hyperparameters are fractional, so they must be parsed as floats
parser.add_argument('--learning_rate', default=0.0002, type=float)
parser.add_argument('--beta1', default=0.5, type=float, help='beta1 parameter for adam optimizer')
parser.add_argument('--beta2', default=0.999, type=float, help='beta2 parameter for adam optimizer')
parser.add_argument('--z_dim', default=100, type=int, help='Enter the dimension of the noise vector')
parser.add_argument('--exp_name', default='default-mnist', help='Enter the name of the experiment')
args = parser.parse_args()

# Fixed latent vectors, reused so that sampled image grids are comparable across epochs
fixed_noise = torch.randn(16, args.z_dim, device=device)


# Create the experiment folder
if not os.path.exists(args.exp_name):
    os.makedirs(args.exp_name)

def load_data(use_data):
    # Initialize the data loader object
    data_loader = DataLoad()
    # Load training data into the dataloader
    if use_data == 'mnist':
        train_loader = data_loader.load_data_mnist(batch_size=args.batch_size)
    elif use_data == 'cifar10':
        train_loader = data_loader.load_data_cifar10(batch_size=args.batch_size)
    # Return the data loader for the training set
    return train_loader

def save_checkpoint(state, dirpath, epoch):
    # Save the model in the specified folder
    folder_path = os.path.join(dirpath, 'training_checkpoints')
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    filename = '{}-checkpoint-{}.ckpt'.format(args.dataset, epoch)
    checkpoint_path = os.path.join(folder_path, filename)
    torch.save(state, checkpoint_path)
    print(' checkpoint saved to {} '.format(checkpoint_path))

def generate_image(fakes, image_folder, epoch):
    # Arrange the generated images in a 4-column grid and save it as a PNG
    image_grid = make_grid(fakes.to(device), padding=2, nrow=4, normalize=True)
    if not os.path.exists(image_folder):
        os.makedirs(image_folder)
    save_image(image_grid, '{}/img_{}.png'.format(image_folder, epoch))

# Loss function
criterion = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator(args.dataset, z_dim=args.z_dim)
discriminator = Discriminator(args.dataset)

if cuda:
    generator.cuda()
    discriminator.cuda()
    criterion.cuda()
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.learning_rate, betas=(args.beta1, args.beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.learning_rate, betas=(args.beta1, args.beta2))

# Establish convention for real and fake labels during training
real_label = float(1)
fake_label = float(0)

# Load training data
train_loader = load_data(args.dataset)

# Training Loop
# Create the tensorboard log directory if it does not exist
if not os.path.exists(args.exp_name + '/tensorboard_logs'):
    os.makedirs(args.exp_name + '/tensorboard_logs')
writer = tensorboard.SummaryWriter(log_dir=args.exp_name + '/tensorboard_logs')
print("Starting Training Loop...")
steps = 0
# For each epoch
for epoch in range(args.total_epochs):
    # The discriminator is updated on every batch
    for i, (imgs, _) in enumerate(train_loader):

        ############################
        # (1) Update discriminator network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        # Format batch
        imgs = imgs.to(device)
        # Adversarial ground truths
        valid = torch.full((imgs.size(0), 1), real_label, device=device)
        fake = torch.full((imgs.size(0), 1), fake_label, device=device)
        optimizer_D.zero_grad()
        # Calculate loss on all-real batch
        real_loss = criterion(discriminator(imgs), valid)
        # Generate batch of latent vectors
        noise = torch.randn(imgs.size(0), args.z_dim, device=device)
        # Generate fake image batch with generator
        gen_imgs = generator(noise)
        # Classify the all-fake batch with D (detached, so no gradients reach G)
        # and calculate D's loss on it
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
        # Add the losses from the all-real and all-fake batches
        loss_D = real_loss + fake_loss
        # Calculate the gradients
        loss_D.backward()
        # Update D
        optimizer_D.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        # The paper suggests k discriminator steps per generator step, so that the
        # discriminator stays near its optimal solution as long as the generator
        # changes slowly enough. Here the check is on the epoch counter, so the
        # generator is updated only during every k-th epoch.
        # See the Adversarial nets section in the paper for a detailed explanation
        # (https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)
        ###########################
        if (epoch+1) % k == 0:

            optimizer_G.zero_grad()
            # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            gen_imgs = generator(noise)
            output = discriminator(gen_imgs)
            # Mean probability that the discriminator assigns "real" to generated
            # images, i.e. D(G(z)). A value close to 1 means that the generator
            # has learnt to fool the discriminator.
            D_G_z = output.mean().item()
            # Calculate G's loss based on this output
            loss_G = criterion(output, valid)
            # Calculate gradients for G
            loss_G.backward()
            # Update G
            optimizer_G.step()
            # Output training stats (kept inside this branch because loss_G
            # only exists after a generator update)
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(G(z)): %.4f\t'
                  % (epoch+1, args.total_epochs, i+1, len(train_loader),
                     loss_D.item(), loss_G.item(), D_G_z))

            writer.add_scalar('D_G_z', D_G_z, steps)
            writer.add_scalar('Discriminator_loss', loss_D.item(), steps)
            writer.add_scalar('Generator_loss', loss_G.item(), steps)
            steps += 1

    if (epoch+1) % args.model_save_frequency == 0:
        # Save the model and optimizer states
        save_checkpoint({
            'epoch': epoch + 1,
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D': optimizer_D.state_dict(),
        }, args.exp_name, epoch + 1)
    # Generate images from the generator network
    if epoch % args.image_sample_frequency == 0:
        with torch.no_grad():
            fakes = generator(fixed_noise)
        image_folder = args.exp_name + '/genereated_images'
        generate_image(fakes, image_folder, epoch)
writer.close()

--------------------------------------------------------------------------------
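A note on the optimizer flags in `train.py`: the learning rate and the Adam betas are fractional, so they need `type=float`. With `type=int`, `argparse` would still accept the float defaults (non-string defaults are used as-is) but would reject any fractional value passed on the command line. The sketch below illustrates this with a hypothetical `build_parser` helper that is not part of the repository; only the flag names are taken from `train.py`.

```python
import argparse


def build_parser(number_type):
    """Hypothetical helper: a parser with train.py-style optimizer flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', default=0.0002, type=number_type)
    parser.add_argument('--beta1', default=0.5, type=number_type)
    return parser


# Non-string defaults bypass the type callable, so the default survives even with int.
defaults = build_parser(int).parse_args([])
print(defaults.learning_rate, defaults.beta1)

# An explicit fractional value parses when the flag is declared as float.
args = build_parser(float).parse_args(['--learning_rate', '0.001', '--beta1', '0.9'])
print(args.learning_rate, args.beta1)

# The same value is rejected by an int-typed flag: argparse reports the error and exits.
try:
    build_parser(int).parse_args(['--learning_rate', '0.001'])
except SystemExit:
    print('type=int rejects --learning_rate 0.001')
```

This is why a run that relies only on defaults can appear to work with the wrong type, while overriding a value from the command line fails.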