├── .gitignore
├── LICENSE
├── MNIST_data
│   ├── t10k-images-idx3-ubyte.gz
│   ├── t10k-labels-idx1-ubyte.gz
│   ├── train-images-idx3-ubyte.gz
│   └── train-labels-idx1-ubyte.gz
├── README.md
├── gan-notebook.ipynb
├── gan-script-fast.py
├── gan-script.py
├── notebook-images
│   ├── GAN_Discriminator.png
│   ├── GAN_Generator.png
│   ├── GAN_Overall.png
│   └── gan-animation.gif
└── pretrained-model
    ├── checkpoint
    ├── pretrained_gan.ckpt.data-00000-of-00001
    ├── pretrained_gan.ckpt.index
    └── pretrained_gan.ckpt.meta

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | tensorboard/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Jon Bruner
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /MNIST_data/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/MNIST_data/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /MNIST_data/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/MNIST_data/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /MNIST_data/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/MNIST_data/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /MNIST_data/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/MNIST_data/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction to generative adversarial networks 2 | 3 | This repository contains code to accompany [the O'Reilly tutorial on generative adversarial networks](https://www.oreilly.com/learning/generative-adversarial-networks-for-beginners) written by [Jon Bruner](https://github.com/jonbruner) and [Adit Deshpande](https://github.com/adeshpande3). See [the original tutorial](https://www.oreilly.com/learning/generative-adversarial-networks-for-beginners) to run this code in a pre-built environment on O'Reilly's servers with cell-by-cell guidance, or run these files on your own machine. 4 | 5 | There are three versions of our simple GAN model in this repository: 6 | - **[gan-notebook.ipynb](gan-notebook.ipynb)** is identical to the interactive tutorial, available here so that you can run it on your own machine. 7 | - **[gan-script.py](gan-script.py)** is a straightforward Python script containing code drawn directly from the tutorial, to be run from the command line. Note that it doesn't print anything when it's executed, but it does send regular updates to [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) so that you can track its progress. 8 | - **[gan-script-fast.py](gan-script-fast.py)** is a modest refactoring of gan-script.py that runs slightly faster because more of its computations are contained in the TensorFlow graph. 9 | 10 | ## Requirements and installation 11 | In order to run [gan-script.py](gan-script.py) or [gan-script-fast.py](gan-script-fast.py), you'll need **[TensorFlow](https://www.tensorflow.org/install/) version 1.0 or later** and [NumPy](https://docs.scipy.org/doc/numpy/user/install.html). In order to run [gan-notebook.ipynb](gan-notebook.ipynb), you'll additionally need [Jupyter](https://jupyter.readthedocs.io/en/latest/install.html) and [matplotlib](https://matplotlib.org/). 
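If you're unsure which versions you have, a quick check from Python (a minimal snippet, not part of this repository) will confirm that a compatible TensorFlow and NumPy are importable:

```python
# Environment check: confirms TensorFlow 1.0+ and NumPy are installed.
from distutils.version import LooseVersion

import numpy as np
import tensorflow as tf

assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'TensorFlow 1.0 or later is required'
print('TensorFlow', tf.__version__, '| NumPy', np.__version__)
```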
12 | 13 | If you've already got TensorFlow on your machine, then you've got NumPy and should be able to run the raw Python scripts. 14 | 15 | ### Installing Anaconda Python and TensorFlow 16 | The easiest way to install TensorFlow as well as NumPy, Jupyter, and matplotlib is to start with the Anaconda Python distribution. 17 | 18 | 1. Follow the [installation instructions for Anaconda Python](https://www.continuum.io/downloads). **We recommend using Python 3.6.** 19 | 20 | 2. Follow the platform-specific [TensorFlow installation instructions](https://www.tensorflow.org/install/). Be sure to follow the "Installing with Anaconda" process, and create a Conda environment named `tensorflow`. 21 | 22 | 3. If you aren't still inside your Conda TensorFlow environment, enter it by opening your terminal and typing 23 | ```bash 24 | source activate tensorflow 25 | ``` 26 | 27 | 4. Download and unzip [this entire repository from GitHub](https://github.com/jonbruner/generative-adversarial-networks), either interactively, or by entering 28 | ```bash 29 | git clone https://github.com/jonbruner/generative-adversarial-networks.git 30 | ``` 31 | 32 | 5. Use `cd` to navigate into the top directory of the repo on your machine 33 | 34 | 6. Launch Jupyter by entering 35 | ```bash 36 | jupyter notebook 37 | ``` 38 | and, using your browser, navigate to the URL shown in the terminal output (usually http://localhost:8888/) 39 | -------------------------------------------------------------------------------- /gan-notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Generative Adversarial Networks for Beginners\n", 8 | "## Build a neural network that learns to generate handwritten digits.\n", 9 | "### By [Jon Bruner](https://github.com/jonbruner) and [Adit Deshpande](https://github.com/adeshpande3)\n", 10 | "\n", 11 | "This notebook accompanies [the O'Reilly interactive tutorial on generative adversarial networks](https://www.oreilly.com/learning/generative-adversarial-networks-for-beginners). See the original tutorial to run this code in a pre-built environment on O'Reilly's servers with cell-by-cell guidance, or run this notebook on your own machine.\n", 12 | "\n", 13 | "Also, see [gan-script.py](gan-script.py) in this repository for a straight Python implementation of this code.\n", 14 | "\n", 15 | "### Prerequisites\n", 16 | "\n", 17 | "You'll need [TensorFlow](https://www.tensorflow.org/install/), [NumPy](https://docs.scipy.org/doc/numpy/user/install.html), [matplotlib](https://matplotlib.org/) and [Jupyter](https://jupyter.readthedocs.io/en/latest/install.html) in order to run this notebook on your machine. See [the readme](https://github.com/jonbruner/generative-adversarial-networks) for advice on installing these packages." 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Introduction\n", 25 | "\n", 26 | "According to Yann LeCun, “adversarial training is the coolest thing since sliced bread.” Sliced bread certainly never created this much excitement within the deep learning community. Generative adversarial networks—or GANs, for short—have dramatically sharpened the possibility of AI-generated content, and have drawn active research efforts since they were [first described by Ian Goodfellow et al. 
in 2014](https://arxiv.org/abs/1406.2661).\n", 27 | "\n", 28 | "GANs are neural networks that learn to create synthetic data similar to some known input data. For instance, researchers have generated convincing images from [photographs of everything from bedrooms to album covers](https://github.com/Newmu/dcgan_code), and they display a remarkable ability to reflect [higher-order semantic logic](https://github.com/Newmu/dcgan_code).\n", 29 | "\n", 30 | "Those examples are fairly complex, but it's easy to build a GAN that generates very simple images. In this tutorial, we'll build a GAN that analyzes lots of images of handwritten digits and gradually learns to generate new images from scratch—*essentially, we'll be teaching a neural network how to write*.\n", 31 | "\n", 32 | "\n", 33 | "_Sample images from the generative adversarial network that we'll build in this tutorial. During training, it gradually refines its ability to generate digits._" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## GAN architecture\n", 41 | "\n", 42 | "Generative adversarial networks consist of two models: a generative model and a discriminative model.\n", 43 | "\n", 44 | "![caption](notebook-images/GAN_Overall.png)\n", 45 | "\n", 46 | "The discriminator model is a classifier that determines whether a given image looks like a real image from the dataset or like an artificially created image. This is basically a binary classifier that will take the form of a normal convolutional neural network (CNN).\n", 47 | "\n", 48 | "The generator model takes random input values and transforms them into images through a deconvolutional neural network.\n", 49 | "\n", 50 | "Over the course of many training iterations, the weights and biases in the discriminator and the generator are trained through backpropagation. The discriminator learns to tell \"real\" images of handwritten digits apart from \"fake\" images created by the generator. At the same time, the generator uses feedback from the discriminator to learn how to produce convincing images that the discriminator can't distinguish from real images." 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## Getting started\n", 58 | "\n", 59 | "We’re going to create a GAN that will generate handwritten digits that can fool even the best classifiers (and humans too, of course). We'll use [TensorFlow](https://www.tensorflow.org/), a deep learning library open-sourced by Google that makes it easy to train neural networks on GPUs.\n", 60 | "\n", 61 | "This tutorial expects that you're already at least a little bit familiar with TensorFlow. If you're not, we recommend reading \"[Hello, TensorFlow!](https://www.oreilly.com/learning/hello-tensorflow)\" or watching the \"[Hello, Tensorflow!](https://www.safaribooksonline.com/oriole/hello-tensorflow-oriole)\" interactive tutorial on Safari before proceeding." 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Loading MNIST data\n", 69 | "\n", 70 | "We need a set of real handwritten digits to give the discriminator a starting point in distinguishing between real and fake images. We'll use [MNIST](http://yann.lecun.com/exdb/mnist/), a benchmark dataset in deep learning. It consists of 70,000 images of handwritten digits compiled by the U.S. 
National Institute of Standards and Technology from Census Bureau employees and high school students.\n",
71 | "\n",
72 | "Let's start by importing TensorFlow along with a couple of other helpful libraries. We'll also import our MNIST images using a TensorFlow convenience function called `read_data_sets`."
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "import tensorflow as tf\n",
82 | "import numpy as np\n",
83 | "import datetime\n",
84 | "import matplotlib.pyplot as plt\n",
85 | "%matplotlib inline\n",
86 | "\n",
87 | "from tensorflow.examples.tutorials.mnist import input_data\n",
88 | "mnist = input_data.read_data_sets(\"MNIST_data/\")"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "The MNIST variable we created above contains both the images and their labels, divided into a training set called `train` and a validation set called `validation`. (We won't need to worry about the labels in this tutorial.) We can retrieve batches of images by calling `next_batch` on `mnist`. Let's load one image and look at it.\n",
96 | "\n",
97 | "The images are initially formatted as a single row of 784 pixels. We can reshape them into 28 x 28-pixel images and view them using pyplot."
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "scrolled": false
105 | },
106 | "outputs": [],
107 | "source": [
108 | "sample_image = mnist.train.next_batch(1)[0]\n",
109 | "print(sample_image.shape)\n",
110 | "\n",
111 | "sample_image = sample_image.reshape([28, 28])\n",
112 | "plt.imshow(sample_image, cmap='Greys')"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "If you run the cell above again, you'll see a different image from the MNIST training set."
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "## Discriminator network\n",
127 | "\n",
128 | "Our discriminator is a convolutional neural network that takes in an image of size 28 x 28 x 1 as input and returns a single scalar number that indicates whether the input image is \"real\" or \"fake\"—that is, whether it's drawn from the set of MNIST images or generated by the generator.\n",
129 | "\n",
130 | "![caption](notebook-images/GAN_Discriminator.png)\n",
131 | "\n",
132 | "The structure of our discriminator network is based closely on [TensorFlow's sample CNN classifier model](https://www.tensorflow.org/get_started/mnist/pros). It features two convolutional layers that find 5x5-pixel features, and two \"fully connected\" layers that multiply weights by every pixel in the image.\n",
133 | "\n",
134 | "To set up each layer, we start by creating weight and bias variables through [`tf.get_variable`](https://www.tensorflow.org/api_docs/python/tf/get_variable). Weights are initialized from a [truncated normal](https://www.tensorflow.org/api_docs/python/tf/truncated_normal) distribution, and biases are initialized at zero.\n",
135 | "\n",
136 | "[`tf.nn.conv2d()`](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d) is TensorFlow's standard convolution function. It takes four arguments. The first is the input volume (our `28 x 28 x 1` images in this case), and the next is the filter/weight matrix. The final two arguments set the stride and padding of the convolution.
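To make the shape arithmetic concrete, here's a small trace (a sketch separate from the tutorial's code; `w_demo` is a throwaway variable name) showing how stride and padding determine output sizes:

```python
import tensorflow as tf

# With 'SAME' padding, the output height/width is ceil(input / stride),
# so a stride-1 convolution preserves 28 x 28 and a 2 x 2 pool halves it.
images = tf.placeholder(tf.float32, [None, 28, 28, 1])
w_demo = tf.get_variable('w_demo', [5, 5, 1, 32])
conv = tf.nn.conv2d(input=images, filter=w_demo, strides=[1, 1, 1, 1], padding='SAME')
pool = tf.nn.avg_pool(conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
print(conv.shape)  # (?, 28, 28, 32)
print(pool.shape)  # (?, 14, 14, 32)
```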
Those two values affect the dimensions of the output volume.\n", 137 | "\n", 138 | "If you're already comfortable with CNNs, you'll recognize this as a simple binary classifier—nothing fancy. " 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "def discriminator(images, reuse_variables=None):\n", 148 | " with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables) as scope:\n", 149 | " # First convolutional and pool layers\n", 150 | " # This finds 32 different 5 x 5 pixel features\n", 151 | " d_w1 = tf.get_variable('d_w1', [5, 5, 1, 32], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 152 | " d_b1 = tf.get_variable('d_b1', [32], initializer=tf.constant_initializer(0))\n", 153 | " d1 = tf.nn.conv2d(input=images, filter=d_w1, strides=[1, 1, 1, 1], padding='SAME')\n", 154 | " d1 = d1 + d_b1\n", 155 | " d1 = tf.nn.relu(d1)\n", 156 | " d1 = tf.nn.avg_pool(d1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n", 157 | "\n", 158 | " # Second convolutional and pool layers\n", 159 | " # This finds 64 different 5 x 5 pixel features\n", 160 | " d_w2 = tf.get_variable('d_w2', [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 161 | " d_b2 = tf.get_variable('d_b2', [64], initializer=tf.constant_initializer(0))\n", 162 | " d2 = tf.nn.conv2d(input=d1, filter=d_w2, strides=[1, 1, 1, 1], padding='SAME')\n", 163 | " d2 = d2 + d_b2\n", 164 | " d2 = tf.nn.relu(d2)\n", 165 | " d2 = tf.nn.avg_pool(d2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n", 166 | "\n", 167 | " # First fully connected layer\n", 168 | " d_w3 = tf.get_variable('d_w3', [7 * 7 * 64, 1024], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 169 | " d_b3 = tf.get_variable('d_b3', [1024], initializer=tf.constant_initializer(0))\n", 170 | " d3 = tf.reshape(d2, [-1, 7 * 7 * 64])\n", 171 | " d3 = tf.matmul(d3, d_w3)\n", 172 | " d3 = d3 + d_b3\n", 173 | " d3 = tf.nn.relu(d3)\n", 174 | "\n", 175 | " # Second fully connected layer\n", 176 | " d_w4 = tf.get_variable('d_w4', [1024, 1], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 177 | " d_b4 = tf.get_variable('d_b4', [1], initializer=tf.constant_initializer(0))\n", 178 | " d4 = tf.matmul(d3, d_w4) + d_b4\n", 179 | "\n", 180 | " # d4 contains unscaled values\n", 181 | " return d4" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "## Generator network\n", 189 | "\n", 190 | "![caption](notebook-images/GAN_Generator.png)\n", 191 | "\n", 192 | "Now that we have our discriminator defined, let’s take a look at the generator model. We'll base the overall structure of our model on a simple generator published by [Tim O'Shea](https://github.com/osh/KerasGAN).\n", 193 | "\n", 194 | "You can think of the generator as a kind of reverse convolutional neural network. A typical CNN like our discriminator network transforms a 2- or 3-dimensional matrix of pixel values into a single probability. A generator, however, takes a `d`-dimensional vector of noise and upsamples it to become a 28 x 28 image. ReLU and batch normalization are used to stabilize the outputs of each layer.\n", 195 | "\n", 196 | "In our generator network, we use three convolutional layers along with interpolation until a `28 x 28` pixel image is formed. 
(Actually, as you'll see below, we've taken care to form `28 x 28 x 1` images; many TensorFlow tools for dealing with images anticipate that the images will have some number of _channels_—usually 1 for greyscale images or 3 for RGB color images.)\n", 197 | "\n", 198 | "At the output layer we add a [`tf.sigmoid()`](https://www.tensorflow.org/api_docs/python/tf/sigmoid) activation function; this squeezes pixels that would appear grey toward either black or white, resulting in a crisper image." 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "def generator(z, batch_size, z_dim):\n", 208 | " g_w1 = tf.get_variable('g_w1', [z_dim, 3136], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 209 | " g_b1 = tf.get_variable('g_b1', [3136], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 210 | " g1 = tf.matmul(z, g_w1) + g_b1\n", 211 | " g1 = tf.reshape(g1, [-1, 56, 56, 1])\n", 212 | " g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='g_b1')\n", 213 | " g1 = tf.nn.relu(g1)\n", 214 | "\n", 215 | " # Generate 50 features\n", 216 | " g_w2 = tf.get_variable('g_w2', [3, 3, 1, z_dim/2], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 217 | " g_b2 = tf.get_variable('g_b2', [z_dim/2], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 218 | " g2 = tf.nn.conv2d(g1, g_w2, strides=[1, 2, 2, 1], padding='SAME')\n", 219 | " g2 = g2 + g_b2\n", 220 | " g2 = tf.contrib.layers.batch_norm(g2, epsilon=1e-5, scope='g_b2')\n", 221 | " g2 = tf.nn.relu(g2)\n", 222 | " g2 = tf.image.resize_images(g2, [56, 56])\n", 223 | "\n", 224 | " # Generate 25 features\n", 225 | " g_w3 = tf.get_variable('g_w3', [3, 3, z_dim/2, z_dim/4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 226 | " g_b3 = tf.get_variable('g_b3', [z_dim/4], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 227 | " g3 = tf.nn.conv2d(g2, g_w3, strides=[1, 2, 2, 1], padding='SAME')\n", 228 | " g3 = g3 + g_b3\n", 229 | " g3 = tf.contrib.layers.batch_norm(g3, epsilon=1e-5, scope='g_b3')\n", 230 | " g3 = tf.nn.relu(g3)\n", 231 | " g3 = tf.image.resize_images(g3, [56, 56])\n", 232 | "\n", 233 | " # Final convolution with one output channel\n", 234 | " g_w4 = tf.get_variable('g_w4', [1, 1, z_dim/4, 1], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 235 | " g_b4 = tf.get_variable('g_b4', [1], initializer=tf.truncated_normal_initializer(stddev=0.02))\n", 236 | " g4 = tf.nn.conv2d(g3, g_w4, strides=[1, 2, 2, 1], padding='SAME')\n", 237 | " g4 = g4 + g_b4\n", 238 | " g4 = tf.sigmoid(g4)\n", 239 | " \n", 240 | " # Dimensions of g4: batch_size x 28 x 28 x 1\n", 241 | " return g4" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "## Generating a sample image\n", 249 | "\n", 250 | "Now we’ve defined both the generator and discriminator functions. Let’s see what a sample output from an untrained generator looks like.\n", 251 | "\n", 252 | "We need to open a TensorFlow session and create a placeholder for the input to our generator. The shape of the placeholder will be `None, z_dimensions`. The `None` keyword means that the value can be determined at session runtime. We normally have `None` as our first dimension so that we can have variable batch sizes. (With a batch size of 50, the input to the generator would be 50 x 100). 
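As a tiny illustration (separate from the tutorial code), the same placeholder happily accepts batches of different sizes:

```python
import numpy as np
import tensorflow as tf

# The None in the shape leaves the batch dimension unspecified,
# so batches of any size can be fed at run time.
z = tf.placeholder(tf.float32, [None, 100])
doubled = z * 2
with tf.Session() as sess:
    print(sess.run(doubled, {z: np.random.normal(0, 1, [1, 100])}).shape)   # (1, 100)
    print(sess.run(doubled, {z: np.random.normal(0, 1, [50, 100])}).shape)  # (50, 100)
```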
With the `None` keyword, we don't have to specify `batch_size` until later. "
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {
259 | "collapsed": true
260 | },
261 | "outputs": [],
262 | "source": [
263 | "z_dimensions = 100\n",
264 | "z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions])"
265 | ]
266 | },
267 | {
268 | "cell_type": "markdown",
269 | "metadata": {},
270 | "source": [
271 | "Now, we create a variable (`generated_image_output`) that holds the output of the generator, and we'll also initialize the random noise vector that we're going to use as input. The [`np.random.normal()`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.normal.html) function has three arguments. The first and second define the mean and standard deviation for the normal distribution (0 and 1 in our case), and the third defines the shape of the vector (`1 x 100`)."
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 | "generated_image_output = generator(z_placeholder, 1, z_dimensions)\n",
281 | "z_batch = np.random.normal(0, 1, [1, z_dimensions])"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {},
287 | "source": [
288 | "Next, we initialize all the variables, feed our `z_batch` into the placeholder, and run the session.\n",
289 | "\n",
290 | "The [`sess.run()`](https://www.tensorflow.org/api_docs/python/tf/Session#run) function has two arguments. The first is called the \"fetches\" argument; it defines the value you're interested in computing. In our case, we want to see what the output of the generator is. If you look back at the last code snippet, you'll see that the output of the generator function is stored in `generated_image_output`, so we'll use `generated_image_output` for our first argument.\n",
291 | "\n",
292 | "The second argument takes a dictionary of inputs that are substituted into the graph when it runs. This is where we feed in our placeholders. In our example, we need to feed our `z_batch` variable into the `z_placeholder` that we defined earlier. As before, we'll view the image by reshaping it to `28 x 28` pixels and show it with pyplot."
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "with tf.Session() as sess:\n",
302 | "    sess.run(tf.global_variables_initializer())\n",
303 | "    generated_image = sess.run(generated_image_output,\n",
304 | "                               feed_dict={z_placeholder: z_batch})\n",
305 | "    generated_image = generated_image.reshape([28, 28])\n",
306 | "    plt.imshow(generated_image, cmap='Greys')"
307 | ]
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 | "That looks like noise, right? Now we need to train the weights and biases in the generator network to convert random numbers into recognizable digits. Let's look at loss functions and optimization!"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {},
319 | "source": [
320 | "## Training a GAN\n",
321 | "\n",
322 | "One of the trickiest parts about building and tuning GANs is that they have two loss functions: one that encourages the generator to create better images, and the other that encourages the discriminator to distinguish generated images from real images.\n",
323 | "\n",
324 | "We train both the generator and the discriminator simultaneously.
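Schematically, each training step updates both networks in turn. Here's a toy outline of that pattern (pure-Python stand-ins rather than the tutorial's TensorFlow ops; the real loop appears later in this notebook):

```python
import numpy as np

def generate(noise):
    return np.zeros([len(noise), 28, 28, 1])  # stand-in for the generator network

def update_discriminator(real_images, fake_images):
    pass  # stand-in for one gradient step on d_loss_real and d_loss_fake

def update_generator(noise):
    pass  # stand-in for one gradient step on g_loss

for step in range(3):
    noise = np.random.normal(0, 1, [50, 100])
    real_images = np.zeros([50, 28, 28, 1])   # stand-in for an MNIST batch
    update_discriminator(real_images, generate(noise))  # D learns real vs. fake
    update_generator(noise)                             # G learns to fool D
```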
As the discriminator gets better at distinguishing real images from generated images, the generator is able to better tune its weights and biases to generate convincing images.\n",
325 | "\n",
326 | "Here are the inputs and outputs for our networks."
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "tf.reset_default_graph()\n",
336 | "batch_size = 50\n",
337 | "\n",
338 | "z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder') \n",
339 | "# z_placeholder is for feeding input noise to the generator\n",
340 | "\n",
341 | "x_placeholder = tf.placeholder(tf.float32, shape = [None,28,28,1], name='x_placeholder') \n",
342 | "# x_placeholder is for feeding input images to the discriminator\n",
343 | "\n",
344 | "Gz = generator(z_placeholder, batch_size, z_dimensions) \n",
345 | "# Gz holds the generated images\n",
346 | "\n",
347 | "Dx = discriminator(x_placeholder) \n",
348 | "# Dx will hold discriminator prediction probabilities\n",
349 | "# for the real MNIST images\n",
350 | "\n",
351 | "Dg = discriminator(Gz, reuse_variables=True)\n",
352 | "# Dg will hold discriminator prediction probabilities for generated images"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {},
358 | "source": [
359 | "So, let’s first think about what we want out of our networks. The discriminator's goal is to correctly label real MNIST images as real (return a higher output) and generated images as fake (return a lower output). We'll calculate two losses for the discriminator: one loss that compares `Dx` and 1 for real images from the MNIST set, as well as a loss that compares `Dg` and 0 for images from the generator. We'll do this with TensorFlow's [`tf.nn.sigmoid_cross_entropy_with_logits()`](https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits) function, which calculates the cross-entropy losses between `Dx` and 1 and between `Dg` and 0.\n",
360 | "\n",
361 | "`sigmoid_cross_entropy_with_logits` operates on unscaled values rather than probability values from 0 to 1. Take a look at the last line of our discriminator: there's no softmax or sigmoid layer at the end. GANs can fail if their discriminators \"saturate,\" or become confident enough to return exactly 0 when they're given a generated image; that leaves the generator without a useful gradient to descend.\n",
362 | "\n",
363 | "The [`tf.reduce_mean()`](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) function takes the mean value of all of the components in the matrix returned by the cross entropy function. This is a way of reducing the loss to a single scalar value, instead of a vector or matrix."
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": null,
369 | "metadata": {},
370 | "outputs": [],
371 | "source": [
372 | "d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx, labels = tf.ones_like(Dx)))\n",
373 | "d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.zeros_like(Dg)))"
374 | ]
375 | },
376 | {
377 | "cell_type": "markdown",
378 | "metadata": {},
379 | "source": [
380 | "Now let's set up the generator's loss function. We want the generator network to create images that will fool the discriminator: the generator wants the discriminator to output a value close to 1 when it's given an image from the generator.
Therefore, we want to compute the loss between `Dg` and 1."
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": null,
386 | "metadata": {},
387 | "outputs": [],
388 | "source": [
389 | "g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.ones_like(Dg))) "
390 | ]
391 | },
392 | {
393 | "cell_type": "markdown",
394 | "metadata": {},
395 | "source": [
396 | "Now that we have our loss functions, we need to define our optimizers. The optimizer for the generator network needs to update only the generator’s weights, not those of the discriminator. Likewise, when we train the discriminator, we want to hold the generator's weights fixed.\n",
397 | "\n",
398 | "In order to make this distinction, we need to create two lists of variables, one with the discriminator’s weights and biases and another with the generator’s weights and biases. This is where naming all of your TensorFlow variables with a thoughtful scheme can come in handy."
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "metadata": {},
405 | "outputs": [],
406 | "source": [
407 | "tvars = tf.trainable_variables()\n",
408 | "\n",
409 | "d_vars = [var for var in tvars if 'd_' in var.name]\n",
410 | "g_vars = [var for var in tvars if 'g_' in var.name]\n",
411 | "\n",
412 | "print([v.name for v in d_vars])\n",
413 | "print([v.name for v in g_vars])"
414 | ]
415 | },
416 | {
417 | "cell_type": "markdown",
418 | "metadata": {},
419 | "source": [
420 | "Next, we specify our two optimizers. [Adam](https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer) is usually the optimization algorithm of choice for GANs; it utilizes adaptive learning rates and momentum. We call Adam's minimize function and also specify the variables that we want it to update—the generator's weights and biases when we train the generator, and the discriminator's weights and biases when we train the discriminator.\n",
421 | "\n",
422 | "We're setting up two different training operations for the discriminator here: one that trains the discriminator on real images and one that trains the discriminator on fake images. It's sometimes useful to use different learning rates for these two training operations, or to use them separately to [regulate learning in other ways](https://github.com/jonbruner/ezgan)."
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": null,
428 | "metadata": {},
429 | "outputs": [],
430 | "source": [
431 | "# Train the discriminator\n",
432 | "d_trainer_fake = tf.train.AdamOptimizer(0.0003).minimize(d_loss_fake, var_list=d_vars)\n",
433 | "d_trainer_real = tf.train.AdamOptimizer(0.0003).minimize(d_loss_real, var_list=d_vars)\n",
434 | "\n",
435 | "# Train the generator\n",
436 | "g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)"
437 | ]
438 | },
439 | {
440 | "cell_type": "markdown",
441 | "metadata": {},
442 | "source": [
443 | "It can be tricky to get GANs to converge, and they often need to train for a very long time. [TensorBoard](https://www.tensorflow.org/how_tos/summaries_and_tensorboard/) is useful for tracking the training process; it can graph scalar properties like losses, display sample images during training, and illustrate the topology of the neural networks.\n",
444 | "\n",
445 | "If you run this notebook on your own machine, include the cell below.
Then, in a terminal window from the directory that this notebook lives in, run\n", 446 | "\n", 447 | "```\n", 448 | "tensorboard --logdir=tensorboard/\n", 449 | "```\n", 450 | "\n", 451 | "and open TensorBoard by visiting [`http://localhost:6006`](http://localhost:6006) in your web browser." 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "# From this point forward, reuse variables\n", 461 | "tf.get_variable_scope().reuse_variables()\n", 462 | "\n", 463 | "tf.summary.scalar('Generator_loss', g_loss)\n", 464 | "tf.summary.scalar('Discriminator_loss_real', d_loss_real)\n", 465 | "tf.summary.scalar('Discriminator_loss_fake', d_loss_fake)\n", 466 | "\n", 467 | "images_for_tensorboard = generator(z_placeholder, batch_size, z_dimensions)\n", 468 | "tf.summary.image('Generated_images', images_for_tensorboard, 5)\n", 469 | "merged = tf.summary.merge_all()\n", 470 | "logdir = \"tensorboard/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\") + \"/\"\n", 471 | "writer = tf.summary.FileWriter(logdir, sess.graph)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "And now we iterate. We begin by briefly giving the discriminator some initial training; this helps it develop a gradient that's useful to the generator.\n", 479 | "\n", 480 | "Then we move on to the main training loop. When we train the generator, we’ll feed a random `z` vector into the generator and pass its output to the discriminator (this is the `Dg` variable we specified earlier). The generator’s weights and biases will be updated in order to produce images that the discriminator is more likely to classify as real.\n", 481 | "\n", 482 | "To train the discriminator, we’ll feed it a batch of images from the MNIST set to serve as the positive examples, and then train the discriminator again on generated images, using them as negative examples. Remember that as the generator improves its output, the discriminator continues to learn to classify the improved generator images as fake.\n", 483 | "\n", 484 | "Because it takes a long time to train a GAN, **we recommend not running this code block if you're going through this tutorial for the first time**. 
Instead, follow along but then run the following code block, which loads a pre-trained model for us to continue the tutorial.\n", 485 | "\n", 486 | "**If you want to run this code yourself, prepare to wait: it takes about three hours on a fast GPU, but could take ten times that long on a desktop CPU.**" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "metadata": { 493 | "scrolled": true 494 | }, 495 | "outputs": [], 496 | "source": [ 497 | "sess = tf.Session()\n", 498 | "sess.run(tf.global_variables_initializer())\n", 499 | "\n", 500 | "# Pre-train discriminator\n", 501 | "for i in range(300):\n", 502 | " z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])\n", 503 | " real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])\n", 504 | " _, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake],\n", 505 | " {x_placeholder: real_image_batch, z_placeholder: z_batch})\n", 506 | "\n", 507 | " if(i % 100 == 0):\n", 508 | " print(\"dLossReal:\", dLossReal, \"dLossFake:\", dLossFake)\n", 509 | "\n", 510 | "# Train generator and discriminator together\n", 511 | "for i in range(100000):\n", 512 | " real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])\n", 513 | " z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])\n", 514 | "\n", 515 | " # Train discriminator on both real and fake images\n", 516 | " _, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake],\n", 517 | " {x_placeholder: real_image_batch, z_placeholder: z_batch})\n", 518 | "\n", 519 | " # Train generator\n", 520 | " z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])\n", 521 | " _ = sess.run(g_trainer, feed_dict={z_placeholder: z_batch})\n", 522 | "\n", 523 | " if i % 10 == 0:\n", 524 | " # Update TensorBoard with summary statistics\n", 525 | " z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])\n", 526 | " summary = sess.run(merged, {z_placeholder: z_batch, x_placeholder: real_image_batch})\n", 527 | " writer.add_summary(summary, i)\n", 528 | "\n", 529 | " if i % 100 == 0:\n", 530 | " # Every 100 iterations, show a generated image\n", 531 | " print(\"Iteration:\", i, \"at\", datetime.datetime.now())\n", 532 | " z_batch = np.random.normal(0, 1, size=[1, z_dimensions])\n", 533 | " generated_images = generator(z_placeholder, 1, z_dimensions)\n", 534 | " images = sess.run(generated_images, {z_placeholder: z_batch})\n", 535 | " plt.imshow(images[0].reshape([28, 28]), cmap='Greys')\n", 536 | " plt.show()\n", 537 | "\n", 538 | " # Show discriminator's estimate\n", 539 | " im = images[0].reshape([1, 28, 28, 1])\n", 540 | " result = discriminator(x_placeholder)\n", 541 | " estimate = sess.run(result, {x_placeholder: im})\n", 542 | " print(\"Estimate:\", estimate)" 543 | ] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "metadata": {}, 548 | "source": [ 549 | "Because it can take so long to train a GAN, we recommend that you skip the cell above and execute the following cell. It loads a model that we've already trained for several hours on a fast GPU machine, and lets you experiment with the output of a trained GAN." 
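If you did run the full training loop above, you can save your own weights in the same checkpoint format before moving on; a minimal sketch, using the session from that cell (the path `my-model/my_gan.ckpt` is just an example):

```python
import os

# Snapshot the trained variables from the session used in the training loop.
os.makedirs('my-model', exist_ok=True)
saver = tf.train.Saver()
saver.save(sess, 'my-model/my_gan.ckpt')
```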
550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": null, 555 | "metadata": {}, 556 | "outputs": [], 557 | "source": [ 558 | "saver = tf.train.Saver()\n", 559 | "with tf.Session() as sess:\n", 560 | " saver.restore(sess, 'pretrained-model/pretrained_gan.ckpt')\n", 561 | " z_batch = np.random.normal(0, 1, size=[10, z_dimensions])\n", 562 | " z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder') \n", 563 | " generated_images = generator(z_placeholder, 10, z_dimensions)\n", 564 | " images = sess.run(generated_images, {z_placeholder: z_batch})\n", 565 | " for i in range(10):\n", 566 | " plt.imshow(images[i].reshape([28, 28]), cmap='Greys')\n", 567 | " plt.show()" 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "## Training difficulties\n", 575 | "\n", 576 | "GANs are notoriously difficult to train. Without the right hyperparameters, network architecture, and training procedure, the discriminator can overpower the generator, or vice-versa.\n", 577 | "\n", 578 | "In one common failure mode, the discriminator overpowers the generator, classifying generated images as fake with absolute certainty. When the discriminator responds with absolute certainty, it leaves no gradient for the generator to descend. This is partly why we built our discriminator to produce unscaled output rather than passing its output through a sigmoid function that would push its evaluation toward either 0 or 1.\n", 579 | "\n", 580 | "In another common failure mode known as **mode collapse**, the generator discovers and exploits some weakness in the discriminator. You can recognize mode collapse in your GAN if it generates many very similar images regardless of variation in the generator input _z_. Mode collapse can sometimes be corrected by \"strengthening\" the discriminator in some way—for instance, by adjusting its training rate or by reconfiguring its layers.\n", 581 | "\n", 582 | "Researchers have identified a handful of [\"GAN hacks\"](https://github.com/soumith/ganhacks) that can be helpful in building stable GANs." 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": {}, 588 | "source": [ 589 | "## Closing thoughts\n", 590 | "\n", 591 | "GANs have tremendous potential to reshape the digital world that we interact with every day. The field is still very young, and the next great GAN discovery could be yours!" 592 | ] 593 | }, 594 | { 595 | "cell_type": "markdown", 596 | "metadata": {}, 597 | "source": [ 598 | "## Other resources\n", 599 | "\n", 600 | "- [The original GAN paper](https://arxiv.org/abs/1406.2661) by Ian Goodfellow and his collaborators, published in 2014\n", 601 | "- [A more recent tutorial by Goodfellow](https://arxiv.org/abs/1701.00160) that explains GANs in somewhat more accessible terms\n", 602 | "- [A paper by Alec Radford, Luke Metz, and Soumith Chintala](https://arxiv.org/abs/1511.06434) that introduces deep convolutional GANs, whose basic structure we use in our generator in this tutorial. 
Also see [their DCGAN code on GitHub](https://github.com/Newmu/dcgan_code).\n", 603 | "- [A reference collection of generative networks by Agustinus Kristiadi](https://github.com/wiseodd/generative-models), implemented in TensorFlow" 604 | ] 605 | } 606 | ], 607 | "metadata": { 608 | "anaconda-cloud": {}, 609 | "kernelspec": { 610 | "display_name": "Python 3", 611 | "language": "python", 612 | "name": "python3" 613 | }, 614 | "language_info": { 615 | "codemirror_mode": { 616 | "name": "ipython", 617 | "version": 3 618 | }, 619 | "file_extension": ".py", 620 | "mimetype": "text/x-python", 621 | "name": "python", 622 | "nbconvert_exporter": "python", 623 | "pygments_lexer": "ipython3", 624 | "version": "3.6.3" 625 | } 626 | }, 627 | "nbformat": 4, 628 | "nbformat_minor": 1 629 | } 630 | -------------------------------------------------------------------------------- /gan-script-fast.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a straightforward Python implementation of a generative adversarial network. 3 | The code is derived from the O'Reilly interactive tutorial on GANs 4 | (https://www.oreilly.com/learning/generative-adversarial-networks-for-beginners). 5 | 6 | The tutorial's code trades efficiency for clarity in explaining how GANs function; 7 | this script refactors a few things to improve performance, especially on GPU machines. 8 | In particular, it uses a TensorFlow operation to generate random z values and pass them 9 | to the generator; this way, more computations are contained entirely within the 10 | TensorFlow graph. 11 | 12 | A version of this model with explanatory notes is also available on GitHub 13 | at https://github.com/jonbruner/generative-adversarial-networks. 14 | 15 | This script requires TensorFlow and its dependencies in order to run. Please see 16 | the readme for guidance on installing TensorFlow. 17 | 18 | This script won't print summary statistics in the terminal during training; 19 | track progress and see sample images in TensorBoard. 
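The key refactoring (an illustrative excerpt; the actual line appears in
generator() below) is sampling z inside the TensorFlow graph rather than
building NumPy batches and passing them through feed_dict on every iteration:

    z = tf.random_normal([batch_size, z_dim], mean=0, stddev=1, name='z')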
20 | """ 21 | 22 | import tensorflow as tf 23 | import datetime 24 | 25 | # Load MNIST data 26 | from tensorflow.examples.tutorials.mnist import input_data 27 | mnist = input_data.read_data_sets("MNIST_data/") 28 | 29 | # Define the discriminator network 30 | def discriminator(images, reuse_variables=None): 31 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables) as scope: 32 | # First convolutional and pool layers 33 | # This finds 32 different 5 x 5 pixel features 34 | d_w1 = tf.get_variable('d_w1', [5, 5, 1, 32], initializer=tf.truncated_normal_initializer(stddev=0.02)) 35 | d_b1 = tf.get_variable('d_b1', [32], initializer=tf.constant_initializer(0)) 36 | d1 = tf.nn.conv2d(input=images, filter=d_w1, strides=[1, 1, 1, 1], padding='SAME') 37 | d1 = d1 + d_b1 38 | d1 = tf.nn.relu(d1) 39 | d1 = tf.nn.avg_pool(d1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 40 | 41 | # Second convolutional and pool layers 42 | # This finds 64 different 5 x 5 pixel features 43 | d_w2 = tf.get_variable('d_w2', [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.02)) 44 | d_b2 = tf.get_variable('d_b2', [64], initializer=tf.constant_initializer(0)) 45 | d2 = tf.nn.conv2d(input=d1, filter=d_w2, strides=[1, 1, 1, 1], padding='SAME') 46 | d2 = d2 + d_b2 47 | d2 = tf.nn.relu(d2) 48 | d2 = tf.nn.avg_pool(d2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 49 | 50 | # First fully connected layer 51 | d_w3 = tf.get_variable('d_w3', [7 * 7 * 64, 1024], initializer=tf.truncated_normal_initializer(stddev=0.02)) 52 | d_b3 = tf.get_variable('d_b3', [1024], initializer=tf.constant_initializer(0)) 53 | d3 = tf.reshape(d2, [-1, 7 * 7 * 64]) 54 | d3 = tf.matmul(d3, d_w3) 55 | d3 = d3 + d_b3 56 | d3 = tf.nn.relu(d3) 57 | 58 | # Second fully connected layer 59 | d_w4 = tf.get_variable('d_w4', [1024, 1], initializer=tf.truncated_normal_initializer(stddev=0.02)) 60 | d_b4 = tf.get_variable('d_b4', [1], initializer=tf.constant_initializer(0)) 61 | d4 = tf.matmul(d3, d_w4) + d_b4 62 | 63 | # d4 contains unscaled values 64 | return d4 65 | 66 | # Define the generator network 67 | def generator(batch_size, z_dim): 68 | z = tf.random_normal([batch_size, z_dim], mean=0, stddev=1, name='z') 69 | g_w1 = tf.get_variable('g_w1', [z_dim, 3136], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) 70 | g_b1 = tf.get_variable('g_b1', [3136], initializer=tf.truncated_normal_initializer(stddev=0.02)) 71 | g1 = tf.matmul(z, g_w1) + g_b1 72 | g1 = tf.reshape(g1, [-1, 56, 56, 1]) 73 | g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='g_b1') 74 | g1 = tf.nn.relu(g1) 75 | 76 | # Generate 50 features 77 | g_w2 = tf.get_variable('g_w2', [3, 3, 1, z_dim/2], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) 78 | g_b2 = tf.get_variable('g_b2', [z_dim/2], initializer=tf.truncated_normal_initializer(stddev=0.02)) 79 | g2 = tf.nn.conv2d(g1, g_w2, strides=[1, 2, 2, 1], padding='SAME') 80 | g2 = g2 + g_b2 81 | g2 = tf.contrib.layers.batch_norm(g2, epsilon=1e-5, scope='g_b2') 82 | g2 = tf.nn.relu(g2) 83 | g2 = tf.image.resize_images(g2, [56, 56]) 84 | 85 | # Generate 25 features 86 | g_w3 = tf.get_variable('g_w3', [3, 3, z_dim/2, z_dim/4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) 87 | g_b3 = tf.get_variable('g_b3', [z_dim/4], initializer=tf.truncated_normal_initializer(stddev=0.02)) 88 | g3 = tf.nn.conv2d(g2, g_w3, strides=[1, 2, 2, 1], padding='SAME') 89 | g3 = g3 + g_b3 90 | g3 = 
tf.contrib.layers.batch_norm(g3, epsilon=1e-5, scope='g_b3') 91 | g3 = tf.nn.relu(g3) 92 | g3 = tf.image.resize_images(g3, [56, 56]) 93 | 94 | # Final convolution with one output channel 95 | g_w4 = tf.get_variable('g_w4', [1, 1, z_dim/4, 1], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) 96 | g_b4 = tf.get_variable('g_b4', [1], initializer=tf.truncated_normal_initializer(stddev=0.02)) 97 | g4 = tf.nn.conv2d(g3, g_w4, strides=[1, 2, 2, 1], padding='SAME') 98 | g4 = g4 + g_b4 99 | g4 = tf.sigmoid(g4) 100 | 101 | # Dimensions of g4: batch_size x 28 x 28 x 1 102 | return g4 103 | 104 | z_dimensions = 100 105 | batch_size = 50 106 | 107 | x_placeholder = tf.placeholder(tf.float32, shape = [None,28,28,1], name='x_placeholder') 108 | # x_placeholder is for feeding input images to the discriminator 109 | 110 | Gz = generator(batch_size, z_dimensions) 111 | # Gz holds the generated images 112 | 113 | Dx = discriminator(x_placeholder) 114 | # Dx will hold discriminator prediction probabilities 115 | # for the real MNIST images 116 | 117 | Dg = discriminator(Gz, reuse_variables=True) 118 | # Dg will hold discriminator prediction probabilities for generated images 119 | 120 | # Define losses 121 | d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx, labels = tf.ones_like(Dx))) 122 | d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.zeros_like(Dg))) 123 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.ones_like(Dg))) 124 | 125 | # Define variable lists 126 | tvars = tf.trainable_variables() 127 | d_vars = [var for var in tvars if 'd_' in var.name] 128 | g_vars = [var for var in tvars if 'g_' in var.name] 129 | 130 | # Define the optimizers 131 | # Train the discriminator 132 | d_trainer_fake = tf.train.AdamOptimizer(0.0003).minimize(d_loss_fake, var_list=d_vars) 133 | d_trainer_real = tf.train.AdamOptimizer(0.0003).minimize(d_loss_real, var_list=d_vars) 134 | 135 | # Train the generator 136 | g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars) 137 | 138 | # From this point forward, reuse variables 139 | tf.get_variable_scope().reuse_variables() 140 | 141 | sess = tf.Session() 142 | 143 | # Send summary statistics to TensorBoard 144 | tf.summary.scalar('Generator_loss', g_loss) 145 | tf.summary.scalar('Discriminator_loss_real', d_loss_real) 146 | tf.summary.scalar('Discriminator_loss_fake', d_loss_fake) 147 | 148 | images_for_tensorboard = generator(batch_size, z_dimensions) 149 | tf.summary.image('Generated_images', images_for_tensorboard, 5) 150 | merged = tf.summary.merge_all() 151 | logdir = "tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/" 152 | writer = tf.summary.FileWriter(logdir, sess.graph) 153 | 154 | sess.run(tf.global_variables_initializer()) 155 | 156 | # Pre-train discriminator 157 | for i in range(300): 158 | real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1]) 159 | _, __ = sess.run([d_trainer_real, d_trainer_fake], 160 | {x_placeholder: real_image_batch}) 161 | 162 | # Train generator and discriminator together 163 | for i in range(100000): 164 | real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1]) 165 | 166 | # Train discriminator on both real and fake images 167 | _, __ = sess.run([d_trainer_real, d_trainer_fake], 168 | {x_placeholder: real_image_batch}) 169 | 170 | # Train generator 171 | _ = sess.run(g_trainer) 172 | 
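    # Note that no z_batch is fed here: unlike gan-script.py, this script samples
    # z inside the graph (tf.random_normal in generator()), so only the real
    # images need to be supplied through feed_dict.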
173 | if i % 10 == 0: 174 | # Update TensorBoard with summary statistics 175 | summary = sess.run(merged, {x_placeholder: real_image_batch}) 176 | writer.add_summary(summary, i) 177 | 178 | # Optionally, uncomment the following lines to update the checkpoint files attached to the tutorial. 179 | # saver = tf.train.Saver() 180 | # saver.save(sess, 'pretrained-model/pretrained_gan.ckpt') 181 | -------------------------------------------------------------------------------- /gan-script.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a straightforward Python implementation of a generative adversarial network. 3 | The code is drawn directly from the O'Reilly interactive tutorial on GANs 4 | (https://www.oreilly.com/learning/generative-adversarial-networks-for-beginners). 5 | 6 | A version of this model with explanatory notes is also available on GitHub 7 | at https://github.com/jonbruner/generative-adversarial-networks. 8 | 9 | This script requires TensorFlow and its dependencies in order to run. Please see 10 | the readme for guidance on installing TensorFlow. 11 | 12 | This script won't print summary statistics in the terminal during training; 13 | track progress and see sample images in TensorBoard. 14 | """ 15 | 16 | import tensorflow as tf 17 | import numpy as np 18 | import datetime 19 | 20 | # Load MNIST data 21 | from tensorflow.examples.tutorials.mnist import input_data 22 | mnist = input_data.read_data_sets("MNIST_data/") 23 | 24 | # Define the discriminator network 25 | def discriminator(images, reuse_variables=None): 26 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables) as scope: 27 | # First convolutional and pool layers 28 | # This finds 32 different 5 x 5 pixel features 29 | d_w1 = tf.get_variable('d_w1', [5, 5, 1, 32], initializer=tf.truncated_normal_initializer(stddev=0.02)) 30 | d_b1 = tf.get_variable('d_b1', [32], initializer=tf.constant_initializer(0)) 31 | d1 = tf.nn.conv2d(input=images, filter=d_w1, strides=[1, 1, 1, 1], padding='SAME') 32 | d1 = d1 + d_b1 33 | d1 = tf.nn.relu(d1) 34 | d1 = tf.nn.avg_pool(d1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 35 | 36 | # Second convolutional and pool layers 37 | # This finds 64 different 5 x 5 pixel features 38 | d_w2 = tf.get_variable('d_w2', [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.02)) 39 | d_b2 = tf.get_variable('d_b2', [64], initializer=tf.constant_initializer(0)) 40 | d2 = tf.nn.conv2d(input=d1, filter=d_w2, strides=[1, 1, 1, 1], padding='SAME') 41 | d2 = d2 + d_b2 42 | d2 = tf.nn.relu(d2) 43 | d2 = tf.nn.avg_pool(d2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 44 | 45 | # First fully connected layer 46 | d_w3 = tf.get_variable('d_w3', [7 * 7 * 64, 1024], initializer=tf.truncated_normal_initializer(stddev=0.02)) 47 | d_b3 = tf.get_variable('d_b3', [1024], initializer=tf.constant_initializer(0)) 48 | d3 = tf.reshape(d2, [-1, 7 * 7 * 64]) 49 | d3 = tf.matmul(d3, d_w3) 50 | d3 = d3 + d_b3 51 | d3 = tf.nn.relu(d3) 52 | 53 | # Second fully connected layer 54 | d_w4 = tf.get_variable('d_w4', [1024, 1], initializer=tf.truncated_normal_initializer(stddev=0.02)) 55 | d_b4 = tf.get_variable('d_b4', [1], initializer=tf.constant_initializer(0)) 56 | d4 = tf.matmul(d3, d_w4) + d_b4 57 | 58 | # d4 contains unscaled values 59 | return d4 60 | 61 | # Define the generator network 62 | def generator(z, batch_size, z_dim): 63 | g_w1 = tf.get_variable('g_w1', [z_dim, 3136], dtype=tf.float32, 
initializer=tf.truncated_normal_initializer(stddev=0.02)) 64 | g_b1 = tf.get_variable('g_b1', [3136], initializer=tf.truncated_normal_initializer(stddev=0.02)) 65 | g1 = tf.matmul(z, g_w1) + g_b1 66 | g1 = tf.reshape(g1, [-1, 56, 56, 1]) 67 | g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='g_b1') 68 | g1 = tf.nn.relu(g1) 69 | 70 | # Generate 50 features 71 | g_w2 = tf.get_variable('g_w2', [3, 3, 1, z_dim/2], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) 72 | g_b2 = tf.get_variable('g_b2', [z_dim/2], initializer=tf.truncated_normal_initializer(stddev=0.02)) 73 | g2 = tf.nn.conv2d(g1, g_w2, strides=[1, 2, 2, 1], padding='SAME') 74 | g2 = g2 + g_b2 75 | g2 = tf.contrib.layers.batch_norm(g2, epsilon=1e-5, scope='g_b2') 76 | g2 = tf.nn.relu(g2) 77 | g2 = tf.image.resize_images(g2, [56, 56]) 78 | 79 | # Generate 25 features 80 | g_w3 = tf.get_variable('g_w3', [3, 3, z_dim/2, z_dim/4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) 81 | g_b3 = tf.get_variable('g_b3', [z_dim/4], initializer=tf.truncated_normal_initializer(stddev=0.02)) 82 | g3 = tf.nn.conv2d(g2, g_w3, strides=[1, 2, 2, 1], padding='SAME') 83 | g3 = g3 + g_b3 84 | g3 = tf.contrib.layers.batch_norm(g3, epsilon=1e-5, scope='g_b3') 85 | g3 = tf.nn.relu(g3) 86 | g3 = tf.image.resize_images(g3, [56, 56]) 87 | 88 | # Final convolution with one output channel 89 | g_w4 = tf.get_variable('g_w4', [1, 1, z_dim/4, 1], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02)) 90 | g_b4 = tf.get_variable('g_b4', [1], initializer=tf.truncated_normal_initializer(stddev=0.02)) 91 | g4 = tf.nn.conv2d(g3, g_w4, strides=[1, 2, 2, 1], padding='SAME') 92 | g4 = g4 + g_b4 93 | g4 = tf.sigmoid(g4) 94 | 95 | # Dimensions of g4: batch_size x 28 x 28 x 1 96 | return g4 97 | 98 | z_dimensions = 100 99 | batch_size = 50 100 | z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder') 101 | # z_placeholder is for feeding input noise to the generator 102 | 103 | x_placeholder = tf.placeholder(tf.float32, shape = [None,28,28,1], name='x_placeholder') 104 | # x_placeholder is for feeding input images to the discriminator 105 | 106 | Gz = generator(z_placeholder, batch_size, z_dimensions) 107 | # Gz holds the generated images 108 | 109 | Dx = discriminator(x_placeholder) 110 | # Dx will hold discriminator prediction probabilities 111 | # for the real MNIST images 112 | 113 | Dg = discriminator(Gz, reuse_variables=True) 114 | # Dg will hold discriminator prediction probabilities for generated images 115 | 116 | # Define losses 117 | d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx, labels = tf.ones_like(Dx))) 118 | d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.zeros_like(Dg))) 119 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.ones_like(Dg))) 120 | 121 | # Define variable lists 122 | tvars = tf.trainable_variables() 123 | d_vars = [var for var in tvars if 'd_' in var.name] 124 | g_vars = [var for var in tvars if 'g_' in var.name] 125 | 126 | # Define the optimizers 127 | # Train the discriminator 128 | d_trainer_fake = tf.train.AdamOptimizer(0.0003).minimize(d_loss_fake, var_list=d_vars) 129 | d_trainer_real = tf.train.AdamOptimizer(0.0003).minimize(d_loss_real, var_list=d_vars) 130 | 131 | # Train the generator 132 | g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars) 133 | 134 | # From this point 
forward, reuse variables 135 | tf.get_variable_scope().reuse_variables() 136 | 137 | sess = tf.Session() 138 | 139 | # Send summary statistics to TensorBoard 140 | tf.summary.scalar('Generator_loss', g_loss) 141 | tf.summary.scalar('Discriminator_loss_real', d_loss_real) 142 | tf.summary.scalar('Discriminator_loss_fake', d_loss_fake) 143 | 144 | images_for_tensorboard = generator(z_placeholder, batch_size, z_dimensions) 145 | tf.summary.image('Generated_images', images_for_tensorboard, 5) 146 | merged = tf.summary.merge_all() 147 | logdir = "tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/" 148 | writer = tf.summary.FileWriter(logdir, sess.graph) 149 | 150 | sess.run(tf.global_variables_initializer()) 151 | 152 | # Pre-train discriminator 153 | for i in range(300): 154 | z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions]) 155 | real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1]) 156 | _, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake], 157 | {x_placeholder: real_image_batch, z_placeholder: z_batch}) 158 | 159 | # Train generator and discriminator together 160 | for i in range(100000): 161 | real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1]) 162 | z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions]) 163 | 164 | # Train discriminator on both real and fake images 165 | _, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake], 166 | {x_placeholder: real_image_batch, z_placeholder: z_batch}) 167 | 168 | # Train generator 169 | z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions]) 170 | _ = sess.run(g_trainer, feed_dict={z_placeholder: z_batch}) 171 | 172 | if i % 10 == 0: 173 | # Update TensorBoard with summary statistics 174 | z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions]) 175 | summary = sess.run(merged, {z_placeholder: z_batch, x_placeholder: real_image_batch}) 176 | writer.add_summary(summary, i) 177 | 178 | # Optionally, uncomment the following lines to update the checkpoint files attached to the tutorial. 
179 | # saver = tf.train.Saver() 180 | # saver.save(sess, 'pretrained-model/pretrained_gan.ckpt') 181 | -------------------------------------------------------------------------------- /notebook-images/GAN_Discriminator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/notebook-images/GAN_Discriminator.png -------------------------------------------------------------------------------- /notebook-images/GAN_Generator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/notebook-images/GAN_Generator.png -------------------------------------------------------------------------------- /notebook-images/GAN_Overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/notebook-images/GAN_Overall.png -------------------------------------------------------------------------------- /notebook-images/gan-animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/notebook-images/gan-animation.gif -------------------------------------------------------------------------------- /pretrained-model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "pretrained_gan.ckpt" 2 | all_model_checkpoint_paths: "pretrained_gan.ckpt" 3 | -------------------------------------------------------------------------------- /pretrained-model/pretrained_gan.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/pretrained-model/pretrained_gan.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /pretrained-model/pretrained_gan.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/pretrained-model/pretrained_gan.ckpt.index -------------------------------------------------------------------------------- /pretrained-model/pretrained_gan.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jonbruner/generative-adversarial-networks/2e792d92b823f5a2c9c8095420869a539aa0819c/pretrained-model/pretrained_gan.ckpt.meta --------------------------------------------------------------------------------