├── README.md └── gan-image-classifier.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Image Classification GAN (work-in-progress) 2 | In this project I will build a GAN (generative adversarial network) for image classification using semi-supervised learning. I am using the SVHN dataset dataset for training. 3 | -------------------------------------------------------------------------------- /gan-image-classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# GANs for Image Classification\n", 8 | "\n", 9 | "In this notebook, I will use GANs for image classification using semi-supervised learning.\n", 10 | "\n", 11 | "In supervised learning, we have a training set of inputs $x$ and class labels $y$. We train a model that takes $x$ as input and gives $y$ as output.\n", 12 | "\n", 13 | "In semi-supervised learning, our goal is still to train a model that takes $x$ as input and generates $y$ as output. However, not all of our training examples have a label $y$. We need to develop an algorithm that is able to get better at classification by studying both labeled $(x, y)$ pairs and unlabeled $x$ examples.\n", 14 | "\n", 15 | "In this example I will be using the SVHN dataset. First I will turn the GAN discriminator into an 11 class discriminator. It will recognize the 10 different classes of real SVHN digits, as well as an 11th class of fake images that come from the generator. The discriminator will get to train on real labeled images, real unlabeled images, and fake images. By drawing on three sources of data instead of just one, it will generalize to the test set much better than a traditional classifier trained on only one source of data." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "collapsed": true 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "%matplotlib inline\n", 27 | "\n", 28 | "import pickle as pkl\n", 29 | "import time\n", 30 | "\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "import numpy as np\n", 33 | "from scipy.io import loadmat\n", 34 | "import tensorflow as tf\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "collapsed": true 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "!mkdir data" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "from urllib.request import urlretrieve\n", 57 | "from os.path import isfile, isdir\n", 58 | "from tqdm import tqdm\n", 59 | "\n", 60 | "data_dir = 'data/'\n", 61 | "\n", 62 | "if not isdir(data_dir):\n", 63 | " raise Exception(\"Data directory doesn't exist!\")\n", 64 | "\n", 65 | "class DLProgress(tqdm):\n", 66 | " last_block = 0\n", 67 | "\n", 68 | " def hook(self, block_num=1, block_size=1, total_size=None):\n", 69 | " self.total = total_size\n", 70 | " self.update((block_num - self.last_block) * block_size)\n", 71 | " self.last_block = block_num\n", 72 | "\n", 73 | "if not isfile(data_dir + \"train_32x32.mat\"):\n", 74 | " with DLProgress(unit='B', unit_scale=True, miniters=1, desc='SVHN Training Set') as pbar:\n", 75 | " urlretrieve(\n", 76 | " 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat',\n", 77 | " data_dir + 'train_32x32.mat',\n", 78 | " pbar.hook)\n", 79 | "\n", 80 | "if not isfile(data_dir + \"test_32x32.mat\"):\n", 81 | " with DLProgress(unit='B', unit_scale=True, miniters=1, desc='SVHN Training Set') as pbar:\n", 82 | " urlretrieve(\n", 83 | " 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat',\n", 84 | " data_dir + 'test_32x32.mat',\n", 85 | " pbar.hook)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "collapsed": true 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "trainset = loadmat(data_dir + 'train_32x32.mat')\n", 97 | "testset = loadmat(data_dir + 'test_32x32.mat')" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "collapsed": true 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "idx = np.random.randint(0, trainset['X'].shape[3], size=36)\n", 109 | "fig, axes = plt.subplots(6, 6, sharex=True, sharey=True, figsize=(5,5),)\n", 110 | "for ii, ax in zip(idx, axes.flatten()):\n", 111 | " ax.imshow(trainset['X'][:,:,:,ii], aspect='equal')\n", 112 | " ax.xaxis.set_visible(False)\n", 113 | " ax.yaxis.set_visible(False)\n", 114 | "plt.subplots_adjust(wspace=0, hspace=0)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": { 121 | "collapsed": true 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "def scale(x, feature_range=(-1, 1)):\n", 126 | " # scale to (0, 1)\n", 127 | " x = ((x - x.min())/(255 - x.min()))\n", 128 | " \n", 129 | " # scale to feature_range\n", 130 | " min, max = feature_range\n", 131 | " x = x * (max - min) + min\n", 132 | " return x" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "class Dataset:\n", 144 | " def __init__(self, train, test, val_frac=0.5, shuffle=True, scale_func=None):\n", 145 | " split_idx = int(len(test['y'])*(1 - val_frac))\n", 146 | " self.test_x, self.valid_x = test['X'][:,:,:,:split_idx], test['X'][:,:,:,split_idx:]\n", 147 | " self.test_y, self.valid_y = test['y'][:split_idx], test['y'][split_idx:]\n", 148 | " self.train_x, self.train_y = train['X'], train['y']\n", 149 | " # The SVHN dataset comes with lots of labels, but for the purpose of this exercise,\n", 150 | " # we will pretend that there are only 1000.\n", 151 | " # We use this mask to say which labels we will allow ourselves to use.\n", 152 | " self.label_mask = np.zeros_like(self.train_y)\n", 153 | " self.label_mask[0:1000] = 1\n", 154 | " \n", 155 | " self.train_x = np.rollaxis(self.train_x, 3)\n", 156 | " self.valid_x = np.rollaxis(self.valid_x, 3)\n", 157 | " self.test_x = np.rollaxis(self.test_x, 3)\n", 158 | " \n", 159 | " if scale_func is None:\n", 160 | " self.scaler = scale\n", 161 | " else:\n", 162 | " self.scaler = scale_func\n", 163 | " self.train_x = self.scaler(self.train_x)\n", 164 | " self.valid_x = self.scaler(self.valid_x)\n", 165 | " self.test_x = self.scaler(self.test_x)\n", 166 | " self.shuffle = shuffle\n", 167 | " \n", 168 | " def batches(self, batch_size, which_set=\"train\"):\n", 169 | " x_name = which_set + \"_x\"\n", 170 | " y_name = which_set + \"_y\"\n", 171 | " \n", 172 | " num_examples = len(getattr(dataset, y_name))\n", 173 | " if self.shuffle:\n", 174 | " idx = np.arange(num_examples)\n", 175 | " np.random.shuffle(idx)\n", 176 | " setattr(dataset, x_name, getattr(dataset, x_name)[idx])\n", 177 | " setattr(dataset, y_name, getattr(dataset, y_name)[idx])\n", 178 | " if which_set == \"train\":\n", 179 | " dataset.label_mask = dataset.label_mask[idx]\n", 180 | " \n", 181 | " dataset_x = getattr(dataset, x_name)\n", 182 | " dataset_y = getattr(dataset, y_name)\n", 183 | " for ii in range(0, num_examples, batch_size):\n", 184 | " x = dataset_x[ii:ii+batch_size]\n", 185 | " y = dataset_y[ii:ii+batch_size]\n", 186 | " \n", 187 | " if which_set == \"train\":\n", 188 | " # When we use the data for training, we need to include\n", 189 | " # the label mask, so we can pretend we don't have access\n", 190 | " # to some of the labels, as an exercise of our semi-supervised\n", 191 | " # learning ability\n", 192 | " yield x, y, self.label_mask[ii:ii+batch_size]\n", 193 | " else:\n", 194 | " yield x, y" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "collapsed": true 202 | }, 203 | "outputs": [], 204 | "source": [ 205 | "def model_inputs(real_dim, z_dim):\n", 206 | " inputs_real = tf.placeholder(tf.float32, (None, *real_dim), name='input_real')\n", 207 | " inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')\n", 208 | " y = tf.placeholder(tf.int32, (None), name='y')\n", 209 | " label_mask = tf.placeholder(tf.int32, (None), name='label_mask')\n", 210 | " \n", 211 | " return inputs_real, inputs_z, y, label_mask" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": { 218 | "collapsed": true 219 | }, 220 | "outputs": [], 221 | "source": [ 222 | "def generator(z, output_dim, reuse=False, alpha=0.2, training=True, size_mult=128)\n", 223 | " with tf.variable_scope('generator', reuse=reuse):\n", 224 | " # First fully connected layer\n", 225 | " x1 = tf.layers.dense(z, 4 * 4 * size_mult * 4)\n", 226 | " # Reshape it to start the convolutional stack\n", 227 | " x1 = tf.reshape(x1, (-1, 4, 4, size_mult * 4))\n", 228 | " x1 = tf.layers.batch_normalization(x1, training=training)\n", 229 | " x1 = tf.maximum(alpha * x1, x1)\n", 230 | " \n", 231 | " x2 = tf.layers.conv2d_transpose(x1, size_mult * 2, 5, strides=2, padding='same')\n", 232 | " x2 = tf.layers.batch_normalization(x2, training=training)\n", 233 | " x2 = tf.maximum(alpha * x2, x2)\n", 234 | " \n", 235 | " x3 = tf.layers.conv2d_transpose(x2, size_mult, 5, strides=2, padding='same')\n", 236 | " x3 = tf.layers.batch_normalization(x3, training=training)\n", 237 | " x3 = tf.maximum(alpha * x3, x3)\n", 238 | "\n", 239 | " # Output layer\n", 240 | " logits = tf.layers.conv2d_transpose(x3, output_dim, 5, strides=2, padding='same')\n", 241 | "\n", 242 | " out = tf.tanh(logits)\n", 243 | "\n", 244 | " return out" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": { 251 | "collapsed": true 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "def discriminator(x, reuse=False, alpha=0.2, drop_rate=0., num_classes=10, size_mult=64)\n", 256 | "#define discriminator function" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": { 263 | "collapsed": true 264 | }, 265 | "outputs": [], 266 | "source": [ 267 | "def model_loss(input_real, input_z, output_dim, y, num_classes, label_mask, alpha=0.2, drop_rate=0.)\n", 268 | "#define model loss" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": { 275 | "collapsed": true 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "def model_opt(d_loss, g_loss, learning_rate, beta1)\n", 280 | "#define optimizer" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": { 287 | "collapsed": true 288 | }, 289 | "outputs": [], 290 | "source": [ 291 | "class GAN\n", 292 | "#define GAN class" 293 | ] 294 | } 295 | ], 296 | "metadata": { 297 | "kernelspec": { 298 | "display_name": "Python 3", 299 | "language": "python", 300 | "name": "python3" 301 | }, 302 | "language_info": { 303 | "codemirror_mode": { 304 | "name": "ipython", 305 | "version": 3 306 | }, 307 | "file_extension": ".py", 308 | "mimetype": "text/x-python", 309 | "name": "python", 310 | "nbconvert_exporter": "python", 311 | "pygments_lexer": "ipython3", 312 | "version": "3.5.2" 313 | } 314 | }, 315 | "nbformat": 4, 316 | "nbformat_minor": 2 317 | } 318 | --------------------------------------------------------------------------------