├── .gitignore ├── LICENSE ├── README.md └── genadv1.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | genadv_tutorial is licensed under the Simplified "2-clause" BSD License: 2 | 3 | Copyright (c) 2015: Eric Jang. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 8 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # genadv_tutorial 2 | 3 | Tutorial on Generative Adversarial Models. 
See the [blog post](http://blog.evjang.com/2016/06/generative-adversarial-nets-in.html). 4 | 5 | Eric Jang 6 | 7 | 30 Dec 2015 8 | 9 | License: BSD 2-Clause 10 | -------------------------------------------------------------------------------- /genadv1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Generative Adversarial Nets\n", 8 | "\n", 9 | "Training a generative adversarial network to sample from a Gaussian distribution. This is a toy problem that takes < 3 minutes to run on a modest 1.2GHz CPU." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 22 | " from ._conv import register_converters as _register_converters\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "import tensorflow as tf\n", 28 | "import numpy as np\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import seaborn as sns # for pretty plots\n", 31 | "from scipy.stats import norm\n", 32 | "%matplotlib inline" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "Target distribution $p_{data}$" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "[]" 51 | ] 52 | }, 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | }, 57 | { 58 | "data": { 59 | "image/png": "\n", 60 | "text/plain": [ 61 | "
# %% Target distribution p_data
mu, sigma = -1, 1
xs = np.linspace(-5, 5, 1000)
plt.plot(xs, norm.pdf(xs, loc=mu, scale=sigma))
#plt.savefig('fig0.png')

# %% Configuration
TRAIN_ITERS = 10000
M = 200  # minibatch size
SEED = 0  # fixed seed so Restart & Run All reproduces the same figures
np.random.seed(SEED)
tf.set_random_seed(SEED)  # graph-level seed; set before any variables are created


# %% MLP - used for D_pre, D1, D2, G networks
def mlp(input, output_dim):
    """Three-layer tanh perceptron: input -> 6 -> 5 -> output_dim.

    Parameters are created with tf.get_variable under the names
    w0, b0, w1, b1, w2, b2, so they belong to whatever tf.variable_scope
    the caller has open (and are shared when that scope is in reuse mode).

    Args:
        input: 2-D float tensor; its second dimension sets the number of
            rows of the first weight matrix.
        output_dim: width of the final layer.

    Returns:
        (output, params): the tanh-activated output tensor and the list
        [w0, b0, w1, b1, w2, b2], in that order.
    """
    layer_dims = [input.get_shape()[1], 6, 5, output_dim]
    params = []
    activation = input
    for idx, (n_in, n_out) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
        w = tf.get_variable("w%d" % idx, [n_in, n_out],
                            initializer=tf.random_normal_initializer())
        b = tf.get_variable("b%d" % idx, [n_out],
                            initializer=tf.constant_initializer(0.0))
        # every layer, including the last, is squashed by tanh
        activation = tf.nn.tanh(tf.matmul(activation, w) + b)
        params += [w, b]
    return activation, params


# %% re-used for optimizing all networks
def momentum_optimizer(loss, var_list):
    """Build a momentum-SGD training op with a staircase-decayed learning rate.

    The learning rate starts at 0.001 and is multiplied by 0.95 every
    TRAIN_ITERS // 4 steps of this optimizer's own step counter; momentum
    is 0.6. Each call creates a separate step counter.

    Args:
        loss: scalar tensor to minimize.
        var_list: variables to update; None is passed straight through to
            Optimizer.minimize, which then uses its default variable set.

    Returns:
        The training op returned by Optimizer.minimize.
    """
    # Step counter that drives the decay schedule. It is bookkeeping, not a
    # model parameter, so keep it out of the trainable-variables collection
    # that minimize(var_list=None) falls back to.
    batch = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        0.001,              # Base learning rate.
        batch,              # Number of updates applied so far.
        TRAIN_ITERS // 4,   # Decay step.
        0.95,               # Decay rate.
        staircase=True)
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.6).minimize(
        loss, global_step=batch, var_list=var_list)
    return optimizer


# %% Pre-train Decision Surface: graph
with tf.variable_scope("D_pre"):
    input_node = tf.placeholder(tf.float32, shape=(M, 1))
    train_labels = tf.placeholder(tf.float32, shape=(M, 1))
    D, theta = mlp(input_node, 1)
    loss = tf.reduce_mean(tf.square(D - train_labels))

# %%
optimizer = momentum_optimizer(loss, None)

# %%
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()


# %% plot decision surface
def plot_d0(D, input_node):
    """Plot the p_data density and the output of network `D` over [-5, 5].

    Reads the module-level `sess`, `mu`, `sigma` and `M`.

    Args:
        D: output tensor to evaluate.
        input_node: placeholder fed with (M, 1) batches of x values.
    """
    f, ax = plt.subplots(1)
    # p_data
    xs = np.linspace(-5, 5, 1000)
    ax.plot(xs, norm.pdf(xs, loc=mu, scale=sigma), label='p_data')
    # decision boundary
    # Every feed is reshaped to exactly (M, 1), so use a whole number of
    # minibatches (about 1000 points); otherwise the leftover points would
    # never be evaluated and would be drawn as zeros.
    n_batches = max(1, 1000 // M)
    r = n_batches * M  # resolution (number of points)
    xs = np.linspace(-5, 5, r)
    ds = np.zeros((r, 1))  # decision surface
    for i in range(n_batches):
        rows = slice(M * i, M * (i + 1))
        ds[rows] = sess.run(D, {input_node: np.reshape(xs[rows], (M, 1))})

    ax.plot(xs, ds, label='decision boundary')
    ax.set_ylim(0, 1.1)
    ax.legend()
# %% Decision surface before pre-training
plot_d0(D, input_node)
plt.title('Initial Decision Boundary')
#plt.savefig('fig1.png')

# %% Pre-train D_pre to regress the p_data density
lh = np.zeros(1000)  # loss history, one entry per pre-training step
for step in range(lh.shape[0]):
    # Inputs are spread uniformly over [-5, 5) rather than drawn from the
    # Gaussian, so the whole plotted domain is covered.
    d = (np.random.random(M) - 0.5) * 10.0
    labels = norm.pdf(d, loc=mu, scale=sigma)
    feed = {input_node: d.reshape(M, 1), train_labels: labels.reshape(M, 1)}
    lh[step], _ = sess.run([loss, optimizer], feed)
# %% Pre-training loss per step
loss_ax = plt.gca()
loss_ax.plot(lh)
loss_ax.set_title('Training Loss')
# %% Decision surface after pre-training
plot_d0(D, input_node)
#plt.savefig('fig2.png')

# %% copy the learned weights over into a tmp array
weightsD = sess.run(theta)

# %% close the pre-training session
sess.close()

# %% Build Net: the actual generative adversarial network
with tf.variable_scope("G"):
    # M noise values; the training loop feeds a jittered grid over [-5, 5]
    z_node = tf.placeholder(tf.float32, shape=(M, 1))
    G, theta_g = mlp(z_node, 1)  # generate normal transformation of Z
    G = tf.multiply(5.0, G)  # mlp ends in tanh, so this maps its output into [-5, 5]
with tf.variable_scope("D") as scope:
    # D(x)
    x_node = tf.placeholder(tf.float32, shape=(M, 1))  # input M normally distributed floats
    fc, theta_d = mlp(x_node, 1)  # output likelihood of being normally distributed
    D1 = tf.maximum(tf.minimum(fc, .99), 0.01)  # clamp as a probability
    # make a copy of D that uses the same variables, but takes in G as input
    scope.reuse_variables()
    fc, theta_d = mlp(G, 1)
    D2 = tf.maximum(tf.minimum(fc, .99), 0.01)
obj_d = tf.reduce_mean(tf.log(D1) + tf.log(1 - D2))
obj_g = tf.reduce_mean(tf.log(D2))

# set up optimizer for G,D; each minimizes 1-obj, i.e. maximizes obj
opt_d = momentum_optimizer(1 - obj_d, theta_d)
opt_g = momentum_optimizer(1 - obj_g, theta_g)  # maximize log(D(G(z)))

# %%
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# %% copy weights from pre-training over to new D network
for variable, pretrained in zip(theta_d, weightsD):
    sess.run(variable.assign(pretrained))


# %%
def plot_fig(n_bins=10):
    """Plot p_data, the decision surface D1 and a histogram estimate of p_g.

    Reads the module-level `sess`, `mu`, `sigma`, `M`, `D1`, `x_node`,
    `G` and `z_node`.

    Args:
        n_bins: number of equal-width bins over [-5, 5] used to estimate
            the density of the generator's outputs.
    """
    f, ax = plt.subplots(1)
    # p_data
    xs = np.linspace(-5, 5, 1000)
    ax.plot(xs, norm.pdf(xs, loc=mu, scale=sigma), label='p_data')

    # The placeholders are fed exactly M rows at a time, so evaluate a whole
    # number of minibatches (about 5000 points); leftover points would
    # otherwise stay at zero.
    n_batches = max(1, 5000 // M)
    r = n_batches * M  # resolution (number of points)

    def run_batched(tensor, placeholder, points):
        """Evaluate `tensor` on `points`, M at a time; returns an (r, 1) array."""
        out = np.zeros((r, 1))
        for i in range(n_batches):
            rows = slice(M * i, M * (i + 1))
            out[rows] = sess.run(tensor, {placeholder: np.reshape(points[rows], (M, 1))})
        return out

    # decision boundary
    xs = np.linspace(-5, 5, r)
    ds = run_batched(D1, x_node, xs)
    ax.plot(xs, ds, label='decision boundary')

    # distribution of inverse-mapped points
    zs = np.linspace(-5, 5, r)
    gs = run_batched(G, z_node, zs)
    # Bin over the fixed plotted range and draw at the bin centres. Letting
    # np.histogram choose its range from gs and then plotting against an
    # unrelated -5..5 grid puts the curve at the wrong x positions.
    edges = np.linspace(-5, 5, n_bins + 1)
    density, _ = np.histogram(gs, bins=edges, density=True)
    centers = 0.5 * (edges[:-1] + edges[1:])
    ax.plot(centers, density, label='p_g')

    # ylim, legend
    ax.set_ylim(0, 1.1)
    ax.legend()
| "outputs": [ 415 | { 416 | "data": { 417 | "text/plain": [ 418 | "Text(0.5,1,'Before Training')" 419 | ] 420 | }, 421 | "execution_count": 20, 422 | "metadata": {}, 423 | "output_type": "execute_result" 424 | }, 425 | { 426 | "data": { 427 | "image/png": "\n", 428 | "text/plain": [ 429 | "
# %% initial conditions
plot_fig()
plt.title('Before Training')
#plt.savefig('fig3.png')


# %% Algorithm 1 of Goodfellow et al 2014
def sample_noise():
    """Return an (M, 1) noise batch: an even grid over [-5, 5] plus a jitter in [0, 0.01)."""
    grid = np.linspace(-5.0, 5.0, M) + np.random.random(M) * 0.01
    return np.reshape(grid, (M, 1))


k = 1  # discriminator updates per generator update
histd, histg = np.zeros(TRAIN_ITERS), np.zeros(TRAIN_ITERS)
# max(1, ...) keeps the modulo below well-defined when TRAIN_ITERS < 10
report_every = max(1, TRAIN_ITERS // 10)
for i in range(TRAIN_ITERS):
    for j in range(k):
        # m-batch sampled from p_data, sorted ascending like the z grid
        # NOTE(review): presumably so real and generated samples line up by
        # rank within the batch; confirm against the accompanying blog post
        x = np.sort(np.random.normal(mu, sigma, M))
        histd[i], _ = sess.run([obj_d, opt_d],
                               {x_node: np.reshape(x, (M, 1)), z_node: sample_noise()})
    histg[i], _ = sess.run([obj_g, opt_g], {z_node: sample_noise()})  # update generator
    if i % report_every == 0:
        print(float(i) / float(TRAIN_ITERS))
# %% Objectives recorded during adversarial training
plt.plot(range(TRAIN_ITERS), histd, label='obj_d')
# This curve is 1-obj_g (the loss handed to the generator's optimizer),
# not obj_g itself, so label it as such.
plt.plot(range(TRAIN_ITERS), 1 - histg, label='1-obj_g')
plt.xlabel('iteration')
plt.legend()
#plt.savefig('fig4.png')
# %% p_data, decision surface and p_g after adversarial training
plot_fig()
# title the figure so it can be told apart from the 'Before Training' one
plt.title('After Training')
#plt.savefig('fig5.png')