├── .DS_Store
├── README.md
├── data_loader.py
├── default-cifar10
│   ├── .DS_Store
│   ├── genereated_images
│   │   ├── .DS_Store
│   │   ├── img_100.png
│   │   ├── img_120.png
│   │   ├── img_140.png
│   │   ├── img_160.png
│   │   ├── img_180.png
│   │   ├── img_20.png
│   │   ├── img_200.png
│   │   ├── img_40.png
│   │   ├── img_60.png
│   │   └── img_80.png
│   ├── loss_graphs
│   │   └── events.out.tfevents.1567243598.deepak-HP.31654.0
│   └── training_checkpoints
│       └── .DS_Store
├── default-mnist
│   ├── .DS_Store
│   ├── genereated_images
│   │   ├── .DS_Store
│   │   ├── mnist_img_0.png
│   │   ├── mnist_img_20.png
│   │   ├── mnist_img_40.png
│   │   ├── mnist_img_60.png
│   │   └── mnist_img_80.png
│   ├── loss_graphs
│   │   ├── .DS_Store
│   │   └── events.out.tfevents.1567240017.deepak-HP.31233.0
│   └── training_checkpoints
│       └── .DS_Store
├── model.py
├── requirements.txt
├── resources
│   ├── .DS_Store
│   ├── cifar_disc_loss.png
│   ├── cifar_gen_loss.png
│   ├── cifar_training.gif
│   ├── generated_image_cifar.png
│   ├── generated_image_mnist.png
│   ├── mnist_disc_loss.png
│   ├── mnist_gen_loss.png
│   └── mnist_training.gif
└── train.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Fully Connected GAN (also known as Vanilla GAN) in PyTorch
3 | This repository contains the code for a fully connected GAN (FCGAN), trained and tested on the MNIST and CIFAR-10 datasets. It is built on the PyTorch framework.
4 |
5 | 
6 |
7 | ## Generative Adversarial Networks
8 | GANs are generally made up of two models: the Artist (generator) and the Critic (discriminator). The generator creates an image from random noise, and the discriminator compares the generated image against the images in the given dataset. The two models are trained on a minimax objective: the generator tries to fool the discriminator by producing realistic-looking images, while the discriminator gets better at telling real images from fake ones. This two-player game continues until the generator produces realistic images or the system reaches a Nash equilibrium.
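
The two-player game described above is commonly expressed through binary cross-entropy losses. The sketch below evaluates those losses on plain Python floats instead of tensors, so it only illustrates the objective and is not the training code of this repository. The function names `discriminator_loss` and `generator_loss` are assumptions made for this example.

```python
import math

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy loss of the discriminator for one real and one fake sample.

    d_real: probability the discriminator assigns to a real image being real.
    d_fake: probability the discriminator assigns to a generated image being real.
    """
    return -math.log(d_real) - math.log(1.0 - d_fake)

def generator_loss(d_fake):
    """Generator loss: low when the discriminator believes the fake is real."""
    return -math.log(d_fake)

# A confident, correct discriminator: low discriminator loss, high generator loss.
print(round(discriminator_loss(0.9, 0.1), 4))  # 0.2107
print(round(generator_loss(0.1), 4))           # 2.3026

# When the discriminator cannot tell real from fake it outputs 0.5 everywhere.
print(round(discriminator_loss(0.5, 0.5), 4))  # 1.3863
```

The last value, 2·ln 2, is the discriminator loss at the point where it can no longer distinguish the two distributions.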
9 |
10 | ## Contents
11 | 1. [Setup Instructions and Dependencies](#1-setup-instructions-and-dependencies)
12 | 2. [Training Model from Scratch](#2-training-model-from-scratch)
13 | 3. [Generating Images from Trained Models](#3-generating-images-from-trained-models)
14 | 4. [Model Architecture](#4-model-architecture)
15 | 5. [Repository Overview](#5-repository-overview)
16 | 6. [Results Obtained](#6-results-obtained)
17 |    1. [Generated Images](#i-generated-images)
18 |    2. [Parameters Used](#ii-parameters-used)
19 |    3. [Loss Curves](#iii-loss-curves)
20 | 7. [Observations](#7-observations)
21 | 8. [Credits](#8-credits)
22 |
23 | ## 1. Setup Instructions and Dependencies
24 | You can set up the repository on your local machine by either downloading it or running the following command in a terminal.
25 |
26 | ``` Batchfile
27 | git clone https://github.com/h3lio5/gan-pytorch.git
28 | ```
29 |
30 | The trained models are large in size and hence their Google Drive links are provided in the `model.txt` file.
31 |
32 | The data required for training is automatically downloaded when running `train.py`.
33 |
34 | All dependencies required by this repo can be downloaded by creating a virtual environment with Python 3.7 and running
35 |
36 | ``` Batchfile
37 | pip install -r requirements.txt
38 | ```
39 |
40 | Make sure to have CUDA 10.0.130 and cuDNN 7.6.0 installed in the virtual environment. For a conda environment, this can be done by using the following commands:
41 |
42 | ```Batchfile
43 | conda install cudatoolkit=10.0
44 | conda install cudnn=7.6.0
45 | ```
46 |
47 | ## 2. Training Model from Scratch
48 | To train your own model from scratch, run
49 |
50 | ```Batchfile
51 | python train.py --dataset mnist --exp_name default-mnist
52 | ```
53 |
54 | + All parameters have default values, so the script also runs without arguments. You can override any of them with the command-line flags defined in `train.py`, such as `--batch_size`, `--total_epochs` or `--learning_rate`.
55 | + The training script creates a folder named after the `--exp_name` argument.
56 | + This folder contains all data related to your experiment, such as TensorBoard logs, images generated during training and training checkpoints.
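
As a rough guide to what ends up in the experiment folder, the sketch below rebuilds the paths from the folder and file names that appear in `train.py`. The helper `experiment_layout` is hypothetical and exists only for this illustration. It does not touch the file system.

```python
from pathlib import PurePosixPath

def experiment_layout(exp_name, dataset, epoch):
    """Return the paths train.py writes to for one experiment (names copied from train.py)."""
    root = PurePosixPath(exp_name)
    return {
        # Checkpoints are named '<dataset>-checkpoint-<epoch>.ckpt'.
        "checkpoint": root / "training_checkpoints" / f"{dataset}-checkpoint-{epoch}.ckpt",
        # The folder name keeps the spelling used in train.py.
        "images": root / "genereated_images",
        "tensorboard": root / "tensorboard_logs",
    }

layout = experiment_layout("default-mnist", "mnist", 20)
for name, path in layout.items():
    print(f"{name}: {path}")
```

To inspect the loss curves, point TensorBoard at the `tensorboard_logs` directory of the experiment.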
57 |
58 | ## 3. Generating Images from Trained Models
59 | To generate images from trained models, run
60 |
61 | ```Batchfile
62 | python generate.py --dataset mnist/cifar10 --load_path path/to/checkpoint --grid_size n --save_path directory/where/images/are/saved
63 | ```
64 |
65 | The arguments are explained below:
66 |
67 | + `--dataset` requires either `mnist` or `cifar10`, matching the dataset the model was trained on.
68 | + `--load_path` requires the path to the training checkpoint to load. Point this towards the *.index file without the extension. For example `--load_path training_checkpoints/ckpt-1`.
69 | + `--grid_size` requires an integer `n` and generates n*n images in a grid.
70 | + `--save_path` requires the path to the directory where the generated images will be saved. If the directory doesn't exist, the script will create it.
71 |
72 | To generate images from pre-trained models, download checkpoint files from the Google Drive link given in the `model.txt` file.
73 |
74 | ## 4. Model Architecture
75 | ### Generator Model
76 | + `MNIST`: The generator model is a 5-layer MLP with LeakyReLU activations in the hidden layers and a Tanh non-linearity on the output layer.
77 | + `CIFAR10`: The generator model is a 6-layer MLP with LeakyReLU activations in the hidden layers and a Tanh non-linearity on the output layer.
78 | + The input is a 100-dimensional noise vector. It is passed through the network to produce either a 28x28x1 (MNIST) or a 32x32x3 (CIFAR-10) image.
79 |
80 | ### Discriminator Model
81 | + `MNIST`: The discriminator model is a 3-layer MLP with LeakyReLU activations in the hidden layers and a Sigmoid non-linearity on the output layer.
82 | + `CIFAR10`: The discriminator model is a 4-layer MLP with LeakyReLU activations in the hidden layers and a Sigmoid non-linearity on the output layer.
83 | + The output is a single number between 0 and 1 that indicates whether the image is real or fake/generated.
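
To get a feel for the size of these networks, the sketch below counts the weights and biases of each MLP from the layer widths listed in `model.py`. The helper `mlp_param_count` is an illustrative name, not part of the repository, and the count assumes every linear layer has a bias term, which is the `nn.Linear` default.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

# Layer widths as they appear in model.py (noise dimension 100).
generators = {
    "mnist": [100, 128, 256, 512, 1024, 28 * 28 * 1],
    "cifar10": [100, 128, 256, 512, 1024, 2048, 32 * 32 * 3],
}
discriminators = {
    "mnist": [28 * 28 * 1, 512, 256, 1],
    "cifar10": [32 * 32 * 3, 1024, 512, 256, 1],
}

for name in ("mnist", "cifar10"):
    print(name,
          "generator:", mlp_param_count(generators[name]),
          "discriminator:", mlp_param_count(discriminators[name]))
# mnist generator: 1506448 discriminator: 533505
# cifar10 generator: 9096576 discriminator: 3803137
```

Most of the CIFAR-10 parameters sit in the layers that touch the 3072-dimensional image, which is typical for fully connected models on image data.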
84 |
85 | ## 5. Repository Overview
86 | This repository contains the following files and folders
87 |
88 | 1. **experiments**: This folder contains data for different runs.
89 |
90 | 2. **resources**: Contains media for `README.md`.
91 |
92 | 3. `data_loader.py`: Contains helper functions that load and preprocess data.
93 |
94 | 4. `generate.py`: Used to generate and save images from trained models.
95 |
96 | 5. `model.py`: Contains helper functions that create generator and discriminator models.
97 |
98 | 6. `model.txt`: Contains google drive links to trained models.
99 |
100 | 7. `requirements.txt`: Lists dependencies for easy setup in virtual environments.
101 |
102 | 8. `train.py`: Contains code to train models from scratch.
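
The preprocessing in `data_loader.py` applies `ToTensor()` followed by `Normalize([0.5], [0.5])`. `ToTensor()` scales pixel values to [0, 1], and the normalization then maps them to [-1, 1], the same range the generator's Tanh output covers. The sketch below shows that mapping on single values. `normalize` and `denormalize` are illustrative names and operate on floats rather than tensors.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-value operation performed by Normalize: (x - mean) / std."""
    return (pixel - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, from the Tanh range [-1, 1] back to [0, 1]."""
    return value * std + mean

print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
print(denormalize(-1.0), denormalize(1.0))             # 0.0 1.0
```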
103 |
104 | ## 6. Results Obtained
105 | ### i. Generated Images
106 | Samples generated after training model for 100 epochs on MNIST.
107 |
108 | 
109 |
110 | Samples generated after training model for 200 epochs on CIFAR-10.
111 |
112 | 
113 |
114 | ### ii. Parameters Used
115 | + Optimizer used is Adam
116 | + Learning rate 0.0002, beta-1 0.5
117 | + Trained for 100 epochs (MNIST) and 200 epochs (CIFAR-10)
118 | + Batch size is 128 for both MNIST and CIFAR-10
119 | + Real images are labelled 1 and generated images 0, as set by `real_label` and `fake_label` in `train.py`
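
These settings translate into a number of optimisation steps per run. The sketch below works this out, assuming the standard training-set sizes (60,000 images for MNIST, 50,000 for CIFAR-10), the 200 CIFAR-10 epochs quoted in the results and observations, and a data loader that keeps the last, smaller batch, which is the `DataLoader` default. `batches_per_epoch` is an illustrative helper, not part of the repository.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    """Number of batches per epoch when the final partial batch is kept."""
    return math.ceil(num_samples / batch_size)

# (training-set size, epochs) for each run; batch size is 128 in both cases.
runs = {"mnist": (60000, 100), "cifar10": (50000, 200)}

for name, (num_samples, epochs) in runs.items():
    per_epoch = batches_per_epoch(num_samples, 128)
    print(name, per_epoch, per_epoch * epochs)
# mnist 469 46900
# cifar10 391 78200
```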
120 |
121 | ### iii. Loss Curves
122 | #### MNIST
123 | 
124 |
125 | #### CIFAR-10
126 | 
127 |
128 | ## 7. Observations
129 | ### MNIST
130 | The model took around 12 minutes to train for 100 epochs on a GPU. The generated images are not very sharp but somewhat resemble the real data. The model is also prone to mode collapse.
131 |
132 | Training for longer (150+ epochs) does not seem to improve the model's performance and sometimes even degrades the quality of the generated images.
133 |
134 | ### CIFAR-10
135 | Training on the CIFAR-10 dataset was challenging. The dataset is more varied and the network has more parameters to train. The model was trained for 200 epochs, which took about 30 minutes.
136 |
137 | However, the main difficulty was judging whether the 32x32 images were 'good enough'. The images are too low-resolution to make out the subject clearly, but they pass at a glance because they look quite similar to the real data.
138 |
139 | Some images are noisy, but most do not contain many artifacts. The noise is partly due to the network training on all 10 classes of the CIFAR-10 dataset at once. Better results could probably be obtained by training the network on one class at a time, but this would take away the robustness of the model.
140 |
141 | ## 8. Credits
142 | To make this repository I referenced multiple sources:
143 | + [Generative Adversarial Networks — Goodfellow et al. (2014)](https://arxiv.org/abs/1406.2661)
144 | + [MIT intro to deep learning (2019) - Lecture 4](https://www.youtube.com/watch?v=yFBFl1cLYx8)
145 | + [Ian Goodfellow: Generative Adversarial Networks (NIPS 2016 tutorial)](https://www.youtube.com/watch?v=HGYYEUSm-0Q)
146 |
147 |
148 |
149 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import mnist,cifar
2 | import torchvision.transforms as transforms
3 | from torch.utils.data import DataLoader
4 |
5 | class DataLoad():
6 |
7 |     def __init__(self):
8 |         pass
9 |
10 |
11 |     def load_data_mnist(self,batch_size=128):
12 |         '''
13 |         Returns a DataLoader over the MNIST training set (60000 images), normalized to [-1, 1].
14 |         Will be divided into ceil(60000/batch_size) shuffled batches of at most (batch_size) each.
15 |         '''
16 |         mnist_data = mnist.MNIST(root='./data/mnist',train=True,download=True,transform=transforms.Compose(
17 |             [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))
18 |         mnist_loader = DataLoader(mnist_data,batch_size=batch_size,shuffle=True)
19 |         return mnist_loader
20 |
21 |     def load_data_cifar10(self,batch_size=128):
22 |         '''
23 |         Returns a DataLoader over the CIFAR10 training set (50000 images), normalized to [-1, 1].
24 |         Will be divided into ceil(50000/batch_size) shuffled batches of at most (batch_size) each.
25 |         '''
26 |         cifar_data = cifar.CIFAR10(root='./data/cifar10',train=True,download=True,transform=transforms.Compose(
27 |             [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))
28 |         cifar_loader = DataLoader(cifar_data,batch_size=batch_size,shuffle=True)
29 |         return cifar_loader
30 |
--------------------------------------------------------------------------------
/default-cifar10/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/.DS_Store
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/.DS_Store
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_100.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_120.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_140.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_140.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_160.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_160.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_180.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_180.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_20.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_200.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_40.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_60.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_60.png
--------------------------------------------------------------------------------
/default-cifar10/genereated_images/img_80.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/genereated_images/img_80.png
--------------------------------------------------------------------------------
/default-cifar10/loss_graphs/events.out.tfevents.1567243598.deepak-HP.31654.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/loss_graphs/events.out.tfevents.1567243598.deepak-HP.31654.0
--------------------------------------------------------------------------------
/default-cifar10/training_checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-cifar10/training_checkpoints/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/genereated_images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_0.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_20.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_40.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_60.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_60.png
--------------------------------------------------------------------------------
/default-mnist/genereated_images/mnist_img_80.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/genereated_images/mnist_img_80.png
--------------------------------------------------------------------------------
/default-mnist/loss_graphs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/loss_graphs/.DS_Store
--------------------------------------------------------------------------------
/default-mnist/loss_graphs/events.out.tfevents.1567240017.deepak-HP.31233.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/loss_graphs/events.out.tfevents.1567240017.deepak-HP.31233.0
--------------------------------------------------------------------------------
/default-mnist/training_checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/default-mnist/training_checkpoints/.DS_Store
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 |
8 | class Generator(nn.Module):
9 |
10 |     def __init__(self,model_type,z_dim=100):
11 |         super(Generator, self).__init__()
12 |
13 |         self.model_type = model_type
14 |         self.image_shape = {'mnist':(1,28,28),
15 |                             'cifar10':(3,32,32)
16 |                             }
17 |         self.models = nn.ModuleDict({
18 |             'mnist': nn.Sequential(
19 |                 nn.Linear(z_dim,128,bias=True),
20 |                 nn.LeakyReLU(0.2,inplace=True),
21 |                 nn.Linear(128,256,bias=True),
22 |                 nn.LeakyReLU(0.2,inplace=True),
23 |                 nn.Linear(256,512,bias=True),
24 |                 nn.LeakyReLU(0.2,inplace=True),
25 |                 nn.Linear(512,1024,bias=True),
26 |                 nn.LeakyReLU(0.2,inplace=True),
27 |                 nn.Linear(1024,int(np.prod(self.image_shape[model_type]))),
28 |                 nn.Tanh()
29 |             ),
30 |             'cifar10':nn.Sequential(
31 |                 nn.Linear(z_dim,128,bias=True),
32 |                 nn.LeakyReLU(0.2,inplace=True),
33 |                 nn.Linear(128,256,bias=True),
34 |                 nn.LeakyReLU(0.2,inplace=True),
35 |                 nn.Linear(256,512,bias=True),
36 |                 nn.LeakyReLU(0.2,inplace=True),
37 |                 nn.Linear(512,1024,bias=True),
38 |                 nn.LeakyReLU(0.2,inplace=True),
39 |                 nn.Linear(1024,2048,bias=True),
40 |                 nn.LeakyReLU(0.2,inplace=True),
41 |                 nn.Linear(2048,int(np.prod(self.image_shape[model_type]))),
42 |                 nn.Tanh()
43 |             )
44 |         })
45 |
46 |     def forward(self, z):
47 |         img = self.models[self.model_type](z)
48 |         img = img.view(img.size(0), *self.image_shape[self.model_type])
49 |         return img
50 |
51 |
52 | class Discriminator(nn.Module):
53 |
54 |     def __init__(self,model_type):
55 |         super(Discriminator, self).__init__()
56 |
57 |         self.model_type = model_type
58 |         self.image_shape = {'mnist':(1,28,28),
59 |                             'cifar10':(3,32,32)
60 |                             }
61 |         self.models = nn.ModuleDict({
62 |             'mnist':nn.Sequential(
63 |                 nn.Linear(int(np.prod(self.image_shape[model_type])), 512),
64 |                 nn.LeakyReLU(0.2, inplace=True),
65 |                 nn.Linear(512, 256),
66 |                 nn.LeakyReLU(0.2, inplace=True),
67 |                 nn.Linear(256, 1),
68 |                 nn.Sigmoid(),
69 |             ),
70 |             'cifar10':nn.Sequential(
71 |                 nn.Linear(int(np.prod(self.image_shape[model_type])), 1024),
72 |                 nn.LeakyReLU(0.2, inplace=True),
73 |                 nn.Linear(1024,512),
74 |                 nn.LeakyReLU(0.2,inplace=True),
75 |                 nn.Linear(512, 256),
76 |                 nn.LeakyReLU(0.2, inplace=True),
77 |                 nn.Linear(256, 1),
78 |                 nn.Sigmoid(),
79 |             )
80 |         })
81 |
82 |     def forward(self, img):
83 |         img_flat = img.view(img.size(0), -1)
84 |         output = self.models[self.model_type](img_flat)
85 |         return output
86 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorboard==1.14.0
2 | torch==1.2.0+cu92
3 | torchvision==0.4.0+cu92
4 | numpy==1.17.0
5 |
--------------------------------------------------------------------------------
/resources/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/.DS_Store
--------------------------------------------------------------------------------
/resources/cifar_disc_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/cifar_disc_loss.png
--------------------------------------------------------------------------------
/resources/cifar_gen_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/cifar_gen_loss.png
--------------------------------------------------------------------------------
/resources/cifar_training.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/cifar_training.gif
--------------------------------------------------------------------------------
/resources/generated_image_cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/generated_image_cifar.png
--------------------------------------------------------------------------------
/resources/generated_image_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/generated_image_mnist.png
--------------------------------------------------------------------------------
/resources/mnist_disc_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/mnist_disc_loss.png
--------------------------------------------------------------------------------
/resources/mnist_gen_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/mnist_gen_loss.png
--------------------------------------------------------------------------------
/resources/mnist_training.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h3lio5/gan-pytorch/4b38951b88528657ac48402e7233ea907bb7fc7a/resources/mnist_training.gif
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from data_loader import DataLoad
3 | from model import *
4 | import torch.nn as nn
5 | from torch.utils import tensorboard
6 | from torch.autograd import Variable
7 | from torchvision.utils import save_image,make_grid
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | import os
11 | import argparse
12 |
13 |
14 | cuda = True if torch.cuda.is_available() else False
15 | device = 'cuda' if cuda else 'cpu'
16 | k = 4
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--batch_size', default=128,type=int,help='Enter the batch size')
19 | parser.add_argument('--total_epochs',default=100,type=int,help='Enter the total number of epochs')
20 | parser.add_argument('--dataset',default='mnist',help='Enter the dataset you want the model to train on')
21 | parser.add_argument('--model_save_frequency',default=20,type=int,help='How often do you want to save the model state')
22 | parser.add_argument('--image_sample_frequency',default=20,type=int,help='How often do you want to sample images ')
23 | parser.add_argument('--learning_rate',default=0.0002,type=float)
24 | parser.add_argument('--beta1',default=0.5,type=float,help='beta1 parameter for adam optimizer')
25 | parser.add_argument('--beta2',default=0.999,type=float,help='beta2 parameter for adam optimizer')
26 | parser.add_argument('--z_dim',default=100,type=int,help='Enter the dimension of the noise vector')
27 | parser.add_argument('--exp_name',default='default-mnist',help='Enter the name of the experiment')
28 | args = parser.parse_args()
29 |
30 | fixed_noise = torch.randn(16,args.z_dim,device=device)
31 |
32 |
33 | #Create the experiment folder
34 | if not os.path.exists(args.exp_name):
35 |     os.makedirs(args.exp_name)
36 |
37 | def load_data(use_data):
38 |     # Initialize the data loader object
39 |     data_loader = DataLoad()
40 |     # Load training data into the dataloader
41 |     if use_data == 'mnist':
42 |         train_loader = data_loader.load_data_mnist(batch_size=args.batch_size)
43 |     elif use_data == 'cifar10':
44 |         train_loader = data_loader.load_data_cifar10(batch_size=args.batch_size)
45 |     # Return the data loader for the training set
46 |     return train_loader
47 |
48 | def save_checkpoint(state,dirpath, epoch):
49 |     #Save the model in the specified folder
50 |     folder_path = dirpath+'/training_checkpoints'
51 |     if not os.path.exists(folder_path):
52 |         os.makedirs(folder_path)
53 |     filename = '{}-checkpoint-{}.ckpt'.format(args.dataset,epoch)
54 |     checkpoint_path = os.path.join(folder_path, filename)
55 |     torch.save(state, checkpoint_path)
56 |     print(' checkpoint saved to {} '.format(checkpoint_path))
57 |
58 | def generate_image(fakes,image_folder):
59 |     #Function to generate image grid and save (the file name uses the module-level `epoch` of the training loop)
60 |     image_grid = make_grid(fakes.to(device),padding=2,nrow=4,normalize=True)
61 |     if not os.path.exists(image_folder):
62 |         os.makedirs(image_folder)
63 |     save_image(image_grid,'{}/img_{}.png'.format(image_folder,epoch))
64 |
65 | # Loss function
66 | criterion = torch.nn.BCELoss()
67 |
68 | # Initialize generator and discriminator
69 | generator = Generator(args.dataset)
70 | discriminator = Discriminator(args.dataset)
71 |
72 | if cuda:
73 |     generator.cuda()
74 |     discriminator.cuda()
75 |     criterion.cuda()
76 | # Optimizers
77 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.learning_rate, betas=(args.beta1,args.beta2))
78 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.learning_rate, betas=(args.beta1,args.beta2))
79 |
80 | # Establish convention for real and fake labels during training
81 | real_label = float(1)
82 | fake_label = float(0)
83 |
84 | # Load training data
85 | train_loader = load_data(args.dataset)
86 |
87 | # Training Loop
88 | # Lists to keep track of progress
89 | # Create the runs directory if it does not exist
90 | if not os.path.exists(args.exp_name+'/tensorboard_logs'):
91 |     os.makedirs(args.exp_name+'/tensorboard_logs')
92 | writer = tensorboard.SummaryWriter(log_dir=args.exp_name+'/tensorboard_logs')
93 | print("Starting Training Loop...")
94 | steps = 0
95 | # For each epoch
96 | for epoch in range(args.total_epochs):
97 |     # Update the discriminator k times before updating generator as specified in the paper
98 |     for i, (imgs, _) in enumerate(train_loader):
99 |
100 |         ############################
101 |         # (1) Update discriminator network: maximize log(D(x)) + log(1 - D(G(z)))
102 |         ###########################
103 |         ## Train with all-real batch
104 |         # Format batch
105 |         imgs = imgs.to(device)
106 |         # Adversarial ground truths
107 |         valid = torch.full((imgs.size(0),1), real_label, device=device)
108 |         fake = torch.full((imgs.size(0),1), fake_label, device=device)
109 |         optimizer_D.zero_grad()
110 |         # Calculate loss on all-real batch
111 |         real_loss = criterion(discriminator(imgs), valid)
112 |         # Generate batch of latent vectors
113 |         noise = torch.randn(imgs.size(0), args.z_dim, device=device)
114 |         # Generate fake image batch with generator
115 |         gen_imgs = generator(noise)
116 |         # Classify all fake batch with D
117 |         # Calculate D's loss on the all-fake batch
118 |         fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
119 |         # Add the losses from the all-real and all-fake batches
120 |         loss_D = real_loss + fake_loss
121 |         # Calculate the gradients
122 |         loss_D.backward()
123 |         #Update D
124 |         optimizer_D.step()
125 |
126 |         ############################
127 |         # (2) Update G network: maximize log(D(G(z)))
128 |         # Optimize the generator network only after k steps of optimizing discriminator as
129 |         # specified in the paper. This is done to ensure that the discriminator is being maintained
130 |         # near its optimal solution as long as generator changes slowly enough.
131 |         # Go through the Adversarial nets section in the paper
132 |         # for detailed explanation (https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)
133 |         ###########################
134 |         if (i+1)%k == 0:
135 |
136 |             optimizer_G.zero_grad()
137 |             # fake labels are real for generator cost
138 |             # Since we just updated D, perform another forward pass of all-fake batch through D
139 |             gen_imgs = generator(noise)
140 |             output = discriminator(gen_imgs)
141 |             # Mean discriminator output on the fake batch, i.e. D(G(z)): the probability that
142 |             # fake images are classified as real. If this value is close to 1, then the generator has
143 |             # successfully learnt to fool the discriminator
144 |             D_G_z = output.mean().item()
145 |             # Calculate G's loss based on this output
146 |             loss_G = criterion(output, valid)
147 |             # Calculate gradients for G
148 |             loss_G.backward()
149 |             # Update G
150 |             optimizer_G.step()
151 |             # Output training stats
152 |             print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(G(z)): %.4f\t'
153 |                   % (epoch+1, args.total_epochs, i+1, len(train_loader),
154 |                      loss_D.item(), loss_G.item(), D_G_z))
155 |
156 |             writer.add_scalar('D_G_z',D_G_z,steps)
157 |             writer.add_scalar('Discriminator_loss',loss_D.item(),steps)
158 |             writer.add_scalar('Generator_loss',loss_G.item(),steps)
159 |             steps+=1
160 |
161 |     if (epoch+1) % args.model_save_frequency == 0:
162 |         # Save the model and optimizer states
163 |         save_checkpoint({
164 |             'epoch': epoch + 1,
165 |             'generator': generator.state_dict(),
166 |             'discriminator': discriminator.state_dict(),
167 |             'optimizer_G' : optimizer_G.state_dict(),
168 |             'optimizer_D' : optimizer_D.state_dict(),
169 |         }, args.exp_name, epoch + 1)
170 |     # Generate images from the generator network
171 |     if epoch % args.image_sample_frequency == 0:
172 |         with torch.no_grad():
173 |             fakes = generator(fixed_noise)
174 |             image_folder = args.exp_name + '/genereated_images'
175 |             generate_image(fakes,image_folder)
176 | writer.close()
177 |
--------------------------------------------------------------------------------