├── README.md └── demo2.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # discoGAN 2 | 3 | ## Overview 4 | 5 | This is the code for [this](https://youtu.be/MgdAe-T8obE) video on Youtube by Siraj Raval as part of the Udacity Deep Learning nanodegree. We're going to use a relatively new variant of the generative adversarial network called the [discogan](https://arxiv.org/abs/1703.05192). This will allow us to perfrom style transfer, that is generate an image in the style of another. Very cool stuff. 6 | 7 | ## Dependencies 8 | 9 | * tensorflow 10 | * matplotlib 11 | * numpy 12 | * tqdm 13 | 14 | install dependencies using [pip](https://pip.pypa.io/en/stable/) 15 | 16 | ## Usage 17 | 18 | run `jupyter notebook` in terminal to see the code pop up in your browser 19 | 20 | install juypter [here](http://jupyter.readthedocs.io/en/latest/install.html) if you don't have it. 21 | 22 | ## Credits 23 | 24 | The credits for this code go to [chunyuan](https://github.com/ChunyuanLI/DiscoGAN) i've merely created a wrapper to get people started. 25 | -------------------------------------------------------------------------------- /demo2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": false 7 | }, 8 | "source": [ 9 | "# DiscoGAN\n", 10 | "\n", 11 | "Learning to DISCOver Cross Domain Relations with GANs\n", 12 | "https://www.youtube.com/watch?v=9reHvktowLY\n", 13 | "\n", 14 | "![alt text](https://pbs.twimg.com/media/C7NDNRuXgAAfePz.jpg \"Logo Title Text 1\")\n", 15 | "\n", 16 | "![alt text](http://www.aimechanic.com/wp-content/uploads/2017/03/PyTorch-DiscoGAN.png \"Logo Title Text 1\")\n", 17 | "\n", 18 | "\n", 19 | "Cross-domain relations are natural to us humans.\n", 20 | "- Suit jacket goes with dress shoes\n", 21 | "- english-french translation\n", 22 | "\n", 23 | "Can they be natural to machines?\n", 24 | "It's a conditional image generation problem. \n", 25 | "i.e find a mapping function from one domain to the other \n", 26 | "i.e generate an image in one domain given another image in the other domain. \n", 27 | "\n", 28 | "most of today’s training approaches use explicitly paired\n", 29 | "data, provided by human or another algorithm.\n", 30 | "\n", 31 | "Let's do it with no labels :)\n", 32 | "use cases\n", 33 | "- games\n", 34 | "- design with real time feedback\n", 35 | "\n", 36 | "\n", 37 | "take 1 image and reconstruct in the style of another\n", 38 | "- encoder-decoder? too naive (more like a camera filter) \n", 39 | "- 2 encoder-decoders? backwards compatible style transfer, but still naive\n", 40 | "- 2 encoder-decorers in an adversarial context? Bingo. :) \n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 23, 46 | "metadata": { 47 | "collapsed": true 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "#why GMM?\n", 52 | "#why graph replace for discimrinator?\n", 53 | "#slim repeat operation\n", 54 | "#add 2 references\n", 55 | "\n", 56 | "\n", 57 | "#bridging the python 2 and python 3 gap\n", 58 | "from __future__ import absolute_import\n", 59 | "from __future__ import division\n", 60 | "from __future__ import print_function\n", 61 | "\n", 62 | "import os # saving files\n", 63 | "import numpy as np #matrix math\n", 64 | "\n", 65 | "#visualizing data\n", 66 | "import matplotlib.pyplot as plt\n", 67 | "import matplotlib.cm as cm\n", 68 | "\n", 69 | "#machine learning\n", 70 | "import tensorflow as tf\n", 71 | "\n", 72 | "#gausian mixture model for generating data\n", 73 | "from data_gmm import GMM_distribution, sample_GMM, plot_GMM\n", 74 | "#analyzing data \n", 75 | "from data_utils import shuffle, iter_data\n", 76 | "\n", 77 | "#progress bar\n", 78 | "from tqdm import tqdm\n", 79 | "\n", 80 | "#TF-Slim is a lightweight library for defining, training and evaluating models in TensorFlow. It enables defining complex networks quickly and concisely\n", 81 | "slim = tf.contrib.slim\n", 82 | "\n", 83 | "#Classes that represent batches of statistical distributions. \n", 84 | "#Each class is initialized with parameters that define the distributions\n", 85 | "ds = tf.contrib.distributions\n", 86 | "\n", 87 | "#Create a new graph which compute the targets from the replaced Tensors.\n", 88 | "\n", 89 | "graph_replace = tf.contrib.graph_editor.graph_replace\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 12, 95 | "metadata": { 96 | "collapsed": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "#hyperparams\n", 101 | "\"\"\" parameters \"\"\"\n", 102 | "n_epoch = 1000 #number of epcohs\n", 103 | "batch_size = 64\n", 104 | "dataset_size = 512\n", 105 | "input_dim = 2 #data and labels\n", 106 | "latent_dim = 2 \n", 107 | "eps_dim = 2\n", 108 | "\n", 109 | "\n", 110 | "#discriminator\n", 111 | "n_layer_disc = 2\n", 112 | "n_hidden_disc = 256\n", 113 | "\n", 114 | "#generator \n", 115 | "n_layer_gen = 2\n", 116 | "n_hidden_gen= 256\n", 117 | "\n", 118 | "#inference network (generator #2)\n", 119 | "n_layer_inf = 2\n", 120 | "n_hidden_inf= 256\n" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 13, 126 | "metadata": { 127 | "collapsed": true 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "#save our results to the DiscoGAN folder\n", 132 | "\"\"\" Create directory for results \"\"\"\n", 133 | "result_dir = 'results/DiscoGAN/'\n", 134 | "directory = result_dir\n", 135 | "if not os.path.exists(directory):\n", 136 | " os.makedirs(directory)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "![alt text](https://image.slidesharecdn.com/sampleproject-140814212447-phpapp02/95/speaker-recognition-using-gaussian-mixture-model-2-638.jpg?cb=1408051684\n", 144 | " \"Logo Title Text 1\")\n", 145 | " \n", 146 | "A Gaussian mixture model is a probabilistic model that assumes all the data points are generated from a mixture of a finite number of Gaussian distributions with unknown parameters.\n", 147 | "\n", 148 | "![alt text](http://i.imgur.com/GJhzOUy.png \"Logo Title Text 1\")\n", 149 | "\n", 150 | "\n", 151 | "- X = Dataset of n elements \n", 152 | "- alpha = Mixing weight of the kth component. \n", 153 | "- sigma = Gaussian probability density function\n", 154 | "- mu = Mean of the kth component.\n", 155 | "- sigma2 = Variance of the kth component.\n" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 14, 161 | "metadata": { 162 | "collapsed": false 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "#The demo is tested a toy dataset, \n", 167 | "#5-component GMM\n", 168 | "#A Gaussian mixture model is a probabilistic model \n", 169 | "#that assumes all the data points are generated from \n", 170 | "#a mixture of a finite number of Gaussian distributions \n", 171 | "#with unknown parameters.\n", 172 | "\n", 173 | "# create X dataset (first dataset)\n", 174 | "#applies a function to all the items in an input_list\n", 175 | "#lambda = anonymous functions (i.e. function that is not bound to a name)\n", 176 | "#creates a numpy array of 5 components\n", 177 | "means = map(lambda x: np.array(x), [[0, 0],\n", 178 | " [2, 2],\n", 179 | " [-1, -1],\n", 180 | " [1, -1],\n", 181 | " [-1, 1]])\n", 182 | "\n", 183 | "#convert to list to access methods\n", 184 | "means = list(means)\n", 185 | "#standard deviation\n", 186 | "std = 0.1\n", 187 | "#variances - eye Return an identiy matrix, 2-D array with 1s on the diagonal & 0s elsewhere.\n", 188 | "variances = [np.eye(2) * std for _ in means]\n", 189 | "\n", 190 | "# the probability distribution that would express one's beliefs about this \n", 191 | "#quantity before some evidence is taken into account\n", 192 | "priors = [1.0/len(means) for _ in means]\n", 193 | "\n", 194 | "#create gaussian mixture model \n", 195 | "gaussian_mixture = GMM_distribution(means=means,\n", 196 | " variances=variances,\n", 197 | " priors=priors)\n", 198 | "\n", 199 | "#sample from the data using the GMM\n", 200 | "dataset = sample_GMM(dataset_size, means, variances, priors, sources=('features', ))\n", 201 | "\n", 202 | "#save the results\n", 203 | "save_path = result_dir + 'X_gmm_data.pdf'\n", 204 | "#plot the results\n", 205 | "plot_GMM(dataset, save_path)\n", 206 | "\n", 207 | "#store data and labels\n", 208 | "X_np_data= dataset.data['samples']\n", 209 | "X_labels = dataset.data['label']" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 15, 215 | "metadata": { 216 | "collapsed": true 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "# create Z dataset (second dataset)\n", 221 | "#2-component GMM. \n", 222 | "means = map(lambda x: np.array(x), [[-1, -1],[1, 1]])\n", 223 | "means = list(means)\n", 224 | "std = 0.1\n", 225 | "variances = [np.eye(2) * std for _ in means]\n", 226 | "\n", 227 | "priors = [1.0/len(means) for _ in means]\n", 228 | "\n", 229 | "gaussian_mixture = GMM_distribution(means=means,\n", 230 | " variances=variances,\n", 231 | " priors=priors)\n", 232 | "dataset = sample_GMM(dataset_size, means, variances, priors, sources=('features', ))\n", 233 | "save_path = result_dir + 'Z_gmm_data.pdf'\n", 234 | "plot_GMM(dataset, save_path)\n", 235 | "\n", 236 | "Z_np_data= dataset.data['samples']\n", 237 | "Z_labels = dataset.data['label']" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 16, 243 | "metadata": { 244 | "collapsed": true 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "# samples of x and z\n", 249 | "X_dataset = X_np_data\n", 250 | "Z_dataset = Z_np_data" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "![alt text](https://pbs.twimg.com/media/C8QiTe2XcAATDfn.jpg \"Logo Title Text 1\")\n", 258 | "\n", 259 | "![alt text](http://i.imgur.com/fkEbXXX.png \"Logo Title Text 1\")\n", 260 | ". Illustration of our models on simplified one dimensional domains. (a) ideal mapping from domain A to domain B in which the\n", 261 | "two domain A modes map to two different domain B modes, (b) GAN model failure case, (c) GAN with reconstruction model failure\n", 262 | "case.\n", 263 | "\n", 264 | "\n", 265 | "\n", 266 | "2 coupled models learn the mapping from one domain to another \n", 267 | "as well as the reverse mapping for reconstruction. \n", 268 | "\n", 269 | "The two models are trained together simultaneously.\n", 270 | "\n", 271 | "- 4 generators in total\n", 272 | "- 2 discriminators\n", 273 | "\n", 274 | "The two generators GAB’s and the two generators\n", 275 | "GBA’s share parameters, and the generated images\n", 276 | "xBA and xAB are each fed into separate discriminators LDA\n", 277 | "and LDB , respectively.\n" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 17, 283 | "metadata": { 284 | "collapsed": true 285 | }, 286 | "outputs": [], 287 | "source": [ 288 | "\"\"\" Networks \"\"\"\n", 289 | "\n", 290 | "#Each of the two coupled models learns the mapping from\n", 291 | "#one domain to another, and also the reverse mapping to for\n", 292 | "#reconstruction. The two models are trained together simultaneously.\n", 293 | "#The two generators GAB’s and the two generators\n", 294 | "#GBA’s share parameters, and the generated images\n", 295 | "#xBA and xAB are each fed into separate discriminators LDA\n", 296 | "#and LDB , respectively.\n", 297 | "\n", 298 | "\n", 299 | "\n", 300 | "#2 generators\n", 301 | "def generative_network(z, input_dim, n_layer, n_hidden, eps_dim):\n", 302 | " with tf.variable_scope(\"generative\"):\n", 303 | " h = z\n", 304 | " #repeat allow us to repeatedly perform the same operation.\n", 305 | " #many fully connected layers\n", 306 | " h = slim.repeat(h, n_layer, slim.fully_connected, n_hidden, activation_fn=tf.nn.relu)\n", 307 | " x = slim.fully_connected(h, input_dim, activation_fn=None, scope=\"p_x\")\n", 308 | " return x\n", 309 | "\n", 310 | "\n", 311 | "def inference_network(x, latent_dim, n_layer, n_hidden, eps_dim):\n", 312 | " with tf.variable_scope(\"inference\"):\n", 313 | " h = x\n", 314 | " h = slim.repeat(h, n_layer, slim.fully_connected, n_hidden, activation_fn=tf.nn.relu)\n", 315 | " z = slim.fully_connected(h, latent_dim, activation_fn=None, scope=\"q_z\")\n", 316 | " return z\n", 317 | "\n", 318 | "\n", 319 | "#2 discriminators\n", 320 | "def data_network_x(x, n_layers=2, n_hidden=256, activation_fn=None):\n", 321 | " \"\"\"Approximate x log data density.\"\"\"\n", 322 | " h = tf.concat(x, 1)\n", 323 | " with tf.variable_scope('discriminator_x'):\n", 324 | " h = slim.repeat(h, n_layers, slim.fully_connected, n_hidden, activation_fn=tf.nn.relu)\n", 325 | " log_d = slim.fully_connected(h, 1, activation_fn=activation_fn)\n", 326 | " return tf.squeeze(log_d, squeeze_dims=[1]) #Removes dimensions of size 1 \n", 327 | " #from the shape of a tensor.\n", 328 | "\n", 329 | "\n", 330 | "\n", 331 | "def data_network_z(z, n_layers=2, n_hidden=256, activation_fn=None):\n", 332 | " \"\"\"Approximate z log data density.\"\"\"\n", 333 | " h = tf.concat(z, 1)\n", 334 | " with tf.variable_scope('discriminator_z'):\n", 335 | " h = slim.repeat(h, n_layers, slim.fully_connected, n_hidden, activation_fn=tf.nn.relu)\n", 336 | " log_d = slim.fully_connected(h, 1, activation_fn=activation_fn)\n", 337 | " return tf.squeeze(log_d, squeeze_dims=[1])" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 18, 343 | "metadata": { 344 | "collapsed": false 345 | }, 346 | "outputs": [ 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "INFO:tensorflow:Copying op: concat\n", 352 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 353 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 354 | "INFO:tensorflow:Copying op: discriminator_x/Repeat/fully_connected_1/MatMul\n", 355 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 356 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 357 | "INFO:tensorflow:Copying op: discriminator_x/Repeat/fully_connected_1/BiasAdd\n", 358 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 359 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 360 | "INFO:tensorflow:Copying op: discriminator_x/Repeat/fully_connected_1/Relu\n", 361 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 362 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 363 | "INFO:tensorflow:Copying op: discriminator_x/Repeat/fully_connected_2/MatMul\n", 364 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 365 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 366 | "INFO:tensorflow:Copying op: discriminator_x/Repeat/fully_connected_2/BiasAdd\n", 367 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 368 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 369 | "INFO:tensorflow:Copying op: discriminator_x/Repeat/fully_connected_2/Relu\n", 370 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 371 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 372 | "INFO:tensorflow:Copying op: discriminator_x/fully_connected/MatMul\n", 373 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 374 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 375 | "INFO:tensorflow:Copying op: discriminator_x/fully_connected/BiasAdd\n", 376 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 377 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 378 | "INFO:tensorflow:Copying op: Squeeze\n", 379 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 380 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 381 | "INFO:tensorflow:Finalizing op: concat\n", 382 | "INFO:tensorflow:Finalizing op: discriminator_x/Repeat/fully_connected_1/MatMul\n", 383 | "INFO:tensorflow:Finalizing op: discriminator_x/Repeat/fully_connected_1/BiasAdd\n", 384 | "INFO:tensorflow:Finalizing op: discriminator_x/Repeat/fully_connected_1/Relu\n", 385 | "INFO:tensorflow:Finalizing op: discriminator_x/Repeat/fully_connected_2/MatMul\n", 386 | "INFO:tensorflow:Finalizing op: discriminator_x/Repeat/fully_connected_2/BiasAdd\n", 387 | "INFO:tensorflow:Finalizing op: discriminator_x/Repeat/fully_connected_2/Relu\n", 388 | "INFO:tensorflow:Finalizing op: discriminator_x/fully_connected/MatMul\n", 389 | "INFO:tensorflow:Finalizing op: discriminator_x/fully_connected/BiasAdd\n", 390 | "INFO:tensorflow:Finalizing op: Squeeze\n", 391 | "INFO:tensorflow:Copying op: concat_2\n", 392 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 393 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 394 | "INFO:tensorflow:Copying op: discriminator_z/Repeat/fully_connected_1/MatMul\n", 395 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 396 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 397 | "INFO:tensorflow:Copying op: discriminator_z/Repeat/fully_connected_1/BiasAdd\n", 398 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 399 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 400 | "INFO:tensorflow:Copying op: discriminator_z/Repeat/fully_connected_1/Relu\n", 401 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 402 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 403 | "INFO:tensorflow:Copying op: discriminator_z/Repeat/fully_connected_2/MatMul\n", 404 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 405 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 406 | "INFO:tensorflow:Copying op: discriminator_z/Repeat/fully_connected_2/BiasAdd\n", 407 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 408 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 409 | "INFO:tensorflow:Copying op: discriminator_z/Repeat/fully_connected_2/Relu\n", 410 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 411 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 412 | "INFO:tensorflow:Copying op: discriminator_z/fully_connected/MatMul\n", 413 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 414 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 415 | "INFO:tensorflow:Copying op: discriminator_z/fully_connected/BiasAdd\n", 416 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 417 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 418 | "INFO:tensorflow:Copying op: Squeeze_2\n", 419 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 420 | "WARNING:tensorflow:VARIABLES collection name is deprecated, please use GLOBAL_VARIABLES instead; VARIABLES will be removed after 2017-03-02.\n", 421 | "INFO:tensorflow:Finalizing op: concat_2\n", 422 | "INFO:tensorflow:Finalizing op: discriminator_z/Repeat/fully_connected_1/MatMul\n", 423 | "INFO:tensorflow:Finalizing op: discriminator_z/Repeat/fully_connected_1/BiasAdd\n", 424 | "INFO:tensorflow:Finalizing op: discriminator_z/Repeat/fully_connected_1/Relu\n", 425 | "INFO:tensorflow:Finalizing op: discriminator_z/Repeat/fully_connected_2/MatMul\n", 426 | "INFO:tensorflow:Finalizing op: discriminator_z/Repeat/fully_connected_2/BiasAdd\n", 427 | "INFO:tensorflow:Finalizing op: discriminator_z/Repeat/fully_connected_2/Relu\n", 428 | "INFO:tensorflow:Finalizing op: discriminator_z/fully_connected/MatMul\n", 429 | "INFO:tensorflow:Finalizing op: discriminator_z/fully_connected/BiasAdd\n", 430 | "INFO:tensorflow:Finalizing op: Squeeze_2\n" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "\"\"\" Construct model and training ops \"\"\"\n", 436 | "tf.reset_default_graph()\n", 437 | "\n", 438 | "#data1 input\n", 439 | "x = tf.placeholder(tf.float32, shape=(batch_size, input_dim))\n", 440 | "#data 2 input\n", 441 | "z = tf.placeholder(tf.float32, shape=(batch_size, latent_dim))\n", 442 | "\n", 443 | "# 2 generators - encoders\n", 444 | "p_x = generative_network(z, input_dim , n_layer_gen, n_hidden_gen, eps_dim)\n", 445 | "q_z = inference_network(x, latent_dim, n_layer_inf, n_hidden_inf, eps_dim)\n", 446 | "\n", 447 | "#The logit function is the inverse of the sigmoidal \"logistic\" function\n", 448 | "\n", 449 | "#2 discriminators\n", 450 | "decoder_logit_x = data_network_x(p_x, n_layers=n_layer_disc, n_hidden=n_hidden_disc)\n", 451 | "encoder_logit_x = graph_replace(decoder_logit_x, {p_x: x})\n", 452 | "\n", 453 | "decoder_logit_z = data_network_z(q_z, n_layers=n_layer_disc, n_hidden=n_hidden_disc)\n", 454 | "encoder_logit_z = graph_replace(decoder_logit_z, {q_z: z})\n", 455 | "\n", 456 | "#Computes softplus: log(exp(features) + 1). activation\n", 457 | "#for calculating loss\n", 458 | "encoder_sigmoid_x = tf.nn.softplus(encoder_logit_x)\n", 459 | "decoder_sigmoid_x = tf.nn.softplus(decoder_logit_x)\n", 460 | "encoder_sigmoid_z = tf.nn.softplus(encoder_logit_z)\n", 461 | "decoder_sigmoid_z = tf.nn.softplus(decoder_logit_z)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 20, 467 | "metadata": { 468 | "collapsed": false 469 | }, 470 | "outputs": [], 471 | "source": [ 472 | "#loss functions\n", 473 | "\n", 474 | "#loss for both discriminators\n", 475 | "decoder_loss = decoder_sigmoid_x + decoder_sigmoid_z\n", 476 | "encoder_loss = encoder_sigmoid_x + encoder_sigmoid_z\n", 477 | "\n", 478 | "#combined loss for discriminators\n", 479 | "disc_loss = tf.reduce_mean( encoder_loss ) - tf.reduce_mean( decoder_loss)\n", 480 | "\n", 481 | "#2 more generators (decoders)\n", 482 | "rec_z = inference_network(p_x, latent_dim, n_layer_inf, n_hidden_inf, eps_dim )\n", 483 | "rec_x = generative_network(q_z, input_dim , n_layer_gen, n_hidden_gen, eps_dim )\n", 484 | "\n", 485 | "#compute generator loss\n", 486 | "#Sum of Squared Error loss\n", 487 | "cost_z = tf.reduce_mean(tf.pow(rec_z - z, 2))\n", 488 | "cost_x = tf.reduce_mean(tf.pow(rec_x - x, 2))\n", 489 | "#we tie in discriminator loss into generators loss\n", 490 | "adv_loss = tf.reduce_mean( decoder_loss ) \n", 491 | "gen_loss = 1*adv_loss + 1.*cost_x + 1.*cost_z\n", 492 | "\n", 493 | "#collect vars with names that contain this\n", 494 | "qvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, \"inference\")\n", 495 | "pvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, \"generative\")\n", 496 | "dvars_x = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, \"discriminator_x\")\n", 497 | "dvars_z = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, \"discriminator_z\")\n", 498 | "\n", 499 | "#use adam (gradient descent) to optimize\n", 500 | "opt = tf.train.AdamOptimizer(1e-4, beta1=0.5)\n", 501 | "\n", 502 | "#minimize generators loss\n", 503 | "train_gen_op = opt.minimize(gen_loss, var_list=qvars + pvars)\n", 504 | "\n", 505 | "#minimize discirimaintors loss\n", 506 | "train_disc_op = opt.minimize(disc_loss, var_list=dvars_x + dvars_z)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 21, 512 | "metadata": { 513 | "collapsed": false 514 | }, 515 | "outputs": [ 516 | { 517 | "name": "stderr", 518 | "output_type": "stream", 519 | "text": [ 520 | " 19%|█▊ | 187/1000 [00:50<03:22, 4.02it/s]" 521 | ] 522 | }, 523 | { 524 | "ename": "KeyboardInterrupt", 525 | "evalue": "", 526 | "output_type": "error", 527 | "traceback": [ 528 | "\u001b[0;31m--------------------------------------------------------------------------\u001b[0m", 529 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 530 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mf_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdisc_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_disc_op\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mxmb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mzmb\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mf_g\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0madv_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcost_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcost_z\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_gen_op\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mxmb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mzmb\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mFG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_g\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 531 | "\u001b[0;32m/Users/sraval/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 765\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 766\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 767\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 769\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 532 | "\u001b[0;32m/Users/sraval/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 963\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 965\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 966\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 967\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 533 | "\u001b[0;32m/Users/sraval/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1014\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1015\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1016\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1017\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 534 | "\u001b[0;32m/Users/sraval/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1021\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1022\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1023\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1024\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 535 | "\u001b[0;32m/Users/sraval/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1002\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1003\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1004\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1005\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1006\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 536 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 537 | ] 538 | } 539 | ], 540 | "source": [ 541 | "\"\"\" training \"\"\"\n", 542 | "sess = tf.InteractiveSession()\n", 543 | "sess.run(tf.global_variables_initializer())\n", 544 | "\n", 545 | "FG = []\n", 546 | "FD = []\n", 547 | "\n", 548 | "#for each epoch (log the status bar)\n", 549 | "for epoch in tqdm( range(n_epoch), total=n_epoch):\n", 550 | " #sample from both our datasets\n", 551 | " X_dataset, Z_dataset= shuffle(X_dataset, Z_dataset)\n", 552 | "\n", 553 | " #for each x and z in our data \n", 554 | " for xmb, zmb in iter_data(X_dataset, Z_dataset, size=batch_size):\n", 555 | " \n", 556 | " #minimize our loss functions\n", 557 | " for _ in range(1):\n", 558 | " f_d, _ = sess.run([disc_loss, train_disc_op], feed_dict={x: xmb, z:zmb})\n", 559 | " for _ in range(5):\n", 560 | " #3 components that make up generator loss\n", 561 | " f_g, _ = sess.run([[adv_loss, cost_x, cost_z], train_gen_op], feed_dict={x: xmb, z:zmb})\n", 562 | "\n", 563 | " FG.append(f_g)\n", 564 | " FD.append(f_d)\n" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": null, 570 | "metadata": { 571 | "collapsed": true 572 | }, 573 | "outputs": [], 574 | "source": [ 575 | "\"\"\" plot the results \"\"\"\n", 576 | "\n", 577 | "n_viz = 1\n", 578 | "imz = np.array([]); rmz = np.array([]); imx = np.array([]); rmx = np.array([]);\n", 579 | "for _ in range(n_viz):\n", 580 | " for xmb, zmb in iter_data(X_np_data, Z_np_data, size=batch_size):\n", 581 | " temp_imz = sess.run(q_z, feed_dict={x: xmb, z:zmb})\n", 582 | " imz = np.vstack([imz, temp_imz]) if imz.size else temp_imz\n", 583 | "\n", 584 | " temp_rmz = sess.run(rec_z, feed_dict={x: xmb, z:zmb})\n", 585 | " rmz = np.vstack([rmz, temp_rmz]) if rmz.size else temp_rmz\n", 586 | "\n", 587 | " temp_imx = sess.run(p_x, feed_dict={x: xmb, z:zmb})\n", 588 | " imx = np.vstack([imx, temp_imx]) if imx.size else temp_imx\n", 589 | "\n", 590 | " temp_rmx = sess.run(rec_x, feed_dict={x: xmb, z:zmb})\n", 591 | " rmx = np.vstack([rmx, temp_rmx]) if rmx.size else temp_rmx\n", 592 | "\n", 593 | "## inferred marginal z\n", 594 | "fig_mz, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 4.5))\n", 595 | "ll = np.tile(X_labels, (n_viz))\n", 596 | "ax.scatter(imz[:, 0], imz[:, 1], c=cm.Set1(ll.astype(float)/input_dim/2.0),\n", 597 | " edgecolor='none', alpha=0.5)\n", 598 | "ax.set_xlim(-3, 3); ax.set_ylim(-3.5, 3.5)\n", 599 | "ax.set_xlabel('$z_1$'); ax.set_ylabel('$z_2$')\n", 600 | "ax.axis('on')\n", 601 | "plt.savefig(result_dir + 'inferred_mz.pdf', transparent=True, bbox_inches='tight')\n", 602 | "\n", 603 | "## reconstruced z\n", 604 | "fig_pz, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 4.5))\n", 605 | "ll = np.tile(Z_labels, (n_viz))\n", 606 | "ax.scatter(rmz[:, 0], rmz[:, 1], c=cm.Set1(ll.astype(float)/input_dim/2.0),\n", 607 | " edgecolor='none', alpha=0.5)\n", 608 | "ax.set_xlim(-3, 3); ax.set_ylim(-3.5, 3.5)\n", 609 | "ax.set_xlabel('$z_1$'); ax.set_ylabel('$z_2$')\n", 610 | "ax.axis('on')\n", 611 | "plt.savefig(result_dir + 'reconstruct_mz.pdf', transparent=True, bbox_inches='tight')\n", 612 | "\n", 613 | "## inferred marginal x\n", 614 | "fig_pz, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 4.5))\n", 615 | "ll = np.tile(Z_labels, (n_viz))\n", 616 | "ax.scatter(imx[:, 0], imx[:, 1], c=cm.Set1(ll.astype(float)/input_dim/2.0),\n", 617 | " edgecolor='none', alpha=0.5)\n", 618 | "ax.set_xlim(-3, 3); ax.set_ylim(-3.5, 3.5)\n", 619 | "ax.set_xlabel('$x_1$'); ax.set_ylabel('$x_2$')\n", 620 | "ax.axis('on')\n", 621 | "plt.savefig(result_dir + 'inferred_mx.pdf', transparent=True, bbox_inches='tight')\n", 622 | "\n", 623 | "## reconstruced x\n", 624 | "fig_mx, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 4.5))\n", 625 | "ll = np.tile(X_labels, (n_viz))\n", 626 | "ax.scatter(rmx[:, 0], rmx[:, 1], c=cm.Set1(ll.astype(float)/input_dim/2.0),\n", 627 | " edgecolor='none', alpha=0.5)\n", 628 | "ax.set_xlim(-3, 3); ax.set_ylim(-3.5, 3.5)\n", 629 | "ax.set_xlabel('$x_1$'); ax.set_ylabel('$x_2$')\n", 630 | "ax.axis('on')\n", 631 | "plt.savefig(result_dir + 'reconstruct_mx.pdf', transparent=True, bbox_inches='tight')\n", 632 | "\n", 633 | "## learning curves\n", 634 | "fig_curve, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 4.5))\n", 635 | "ax.plot(FD, label=\"Discriminator\")\n", 636 | "ax.plot(np.array(FG)[:,0], label=\"Generator\")\n", 637 | "ax.plot(np.array(FG)[:,1], label=\"Reconstruction x\")\n", 638 | "ax.plot(np.array(FG)[:,2], label=\"Reconstruction Z\")\n", 639 | "plt.xlabel('Iteration')\n", 640 | "plt.xlabel('Loss')\n", 641 | "ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)\n", 642 | "ax.axis('on')\n", 643 | "plt.savefig(result_dir + 'learning_curves.pdf', bbox_inches='tight')\n" 644 | ] 645 | } 646 | ], 647 | "metadata": { 648 | "kernelspec": { 649 | "display_name": "Python 3", 650 | "language": "python", 651 | "name": "python3" 652 | }, 653 | "language_info": { 654 | "codemirror_mode": { 655 | "name": "ipython", 656 | "version": 3 657 | }, 658 | "file_extension": ".py", 659 | "mimetype": "text/x-python", 660 | "name": "python", 661 | "nbconvert_exporter": "python", 662 | "pygments_lexer": "ipython3", 663 | "version": "3.6.0" 664 | } 665 | }, 666 | "nbformat": 4, 667 | "nbformat_minor": 2 668 | } 669 | --------------------------------------------------------------------------------