├── code
│   ├── 2d_latent_space_for_test_samples.png
│   ├── 2d_latent_space_scan_for_generation.png
│   ├── generated_samples_with_2D_latent_space.png
│   ├── generated_samples_with_10D_latent_space.png
│   ├── generated_samples_with_20D_latent_space.png
│   └── vaecnn-gluon.ipynb
├── README.md
└── .gitignore

/code/2d_latent_space_for_test_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingran/vae-mxnet/HEAD/code/2d_latent_space_for_test_samples.png
--------------------------------------------------------------------------------
/code/2d_latent_space_scan_for_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingran/vae-mxnet/HEAD/code/2d_latent_space_scan_for_generation.png
--------------------------------------------------------------------------------
/code/generated_samples_with_2D_latent_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingran/vae-mxnet/HEAD/code/generated_samples_with_2D_latent_space.png
--------------------------------------------------------------------------------
/code/generated_samples_with_10D_latent_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingran/vae-mxnet/HEAD/code/generated_samples_with_10D_latent_space.png
--------------------------------------------------------------------------------
/code/generated_samples_with_20D_latent_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dingran/vae-mxnet/HEAD/code/generated_samples_with_20D_latent_space.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variational autoencoders in MXNet/Gluon
2 | 3 | Implementations of variational autoencoders using MXNet/Gluon. 4 | 5 | ### References: 6 | 7 | 1. Kingma, Diederik P., and Max Welling. 8 | ["Auto-encoding variational bayes."](https://arxiv.org/abs/1312.6114) 9 | arXiv preprint arXiv:1312.6114 (2013). 10 | 2. Rezende, Danilo Jimenez, Shakir Mohamed, and Daan Wierstra. 11 | ["Stochastic backpropagation and approximate inference in deep generative models."](https://arxiv.org/abs/1401.4082) 12 | arXiv preprint arXiv:1401.4082 (2014). 13 | 14 | ### Code: 15 | 16 | - Implementation using MXNet API (i.e. mxnet.sym, mxnet.mod) [vae-mxnet.ipynb](code/vae-mxnet.ipynb) 17 | - Implementation using Gluon API (i.e. gluon.HybridBlock, autograd) [vae-gluon.ipynb](code/vae-gluon.ipynb) 18 | - CNN-based version, implemented using Gluon [vaecnn-gluon.ipynb](code/vaecnn-gluon.ipynb) 19 | 20 | ### Results: 21 | 22 | - Generated MNIST digits, obtained by randomly sampling the learned latent space 23 | 24 | With 2-D latent space | With 10-D latent space | With 20-D latent space 25 | --- | --- | --- 26 | ![](code/generated_samples_with_2D_latent_space.png) | ![](code/generated_samples_with_10D_latent_space.png) | ![](code/generated_samples_with_20D_latent_space.png) 27 | 28 | - Learned 2-D manifold 29 | 30 | Latent features Z of 1000 test images | Images generated from a grid scan over Z 31 | --- | --- 32 | ![](code/2d_latent_space_for_test_samples.png) | ![](code/2d_latent_space_scan_for_generation.png) 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/
19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | archive 98 | 99 | *.gz 100 | .idea/ 101 | .ipynb_checkpoints/ 102 | 103 | *.params 104 | code/*.json 105 | -------------------------------------------------------------------------------- /code/vaecnn-gluon.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Implementation of Variational Autoencoder in MXNet/Gluon\n", 8 | "\n", 9 | "This is the implementation using the new Gluon API, i.e. 
gluon.HybridBlock and autograd.\n", 10 | "\n", 11 | "Instead of fully connected layers, this implementation uses convolutional networks (similar to those used in DCGAN) to parameterize the recognition and generation networks.\n", 12 | "\n", 13 | "Ref paper: Kingma, Diederik P., and Max Welling. [\"Auto-encoding variational bayes.\"](https://arxiv.org/abs/1312.6114) arXiv preprint arXiv:1312.6114 (2013)." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import time\n", 23 | "import numpy as np\n", 24 | "import mxnet as mx\n", 25 | "from tqdm import tqdm, tqdm_notebook\n", 26 | "from mxnet import nd, autograd, gluon\n", 27 | "from mxnet.gluon import nn\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "data_ctx = mx.cpu()\n", 31 | "model_ctx = mx.gpu(0) # switch to mx.cpu() if no GPU is available\n", 32 | "mx.random.seed(1)\n", 33 | "output_fig = False" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "# Load MNIST" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "image/png": "\n", 51 | "text/plain": [ 52 | "" 53 | ] 54 | }, 55 | "metadata": {}, 56 | "output_type": "display_data" 57 | } 58 | ], 59 | "source": [ 60 | "mnist = mx.test_utils.get_mnist()\n", 61 | "#print(mnist['train_data'][0].shape)\n", 62 | "#plt.imshow(mnist['train_data'][0][0],cmap='Greys')\n", 63 | "\n", 64 | "n_samples = 10\n", 65 | "idx = np.random.choice(len(mnist['train_data']), n_samples)\n", 66 | "_, axarr = plt.subplots(1, n_samples, figsize=(16,4))\n", 67 | "for i,j in enumerate(idx):\n", 68 | " axarr[i].imshow(mnist['train_data'][j][0], cmap='Greys')\n", 69 | " #axarr[i].axis('off')\n", 70 | " axarr[i].get_xaxis().set_ticks([])\n", 71 | " axarr[i].get_yaxis().set_ticks([])\n", 72 | "plt.show()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "#train_data = np.reshape(mnist['train_data'],(-1,28*28))\n", 82 | "#test_data = 
AAACJFGxq60knnWRx6Cfk+aqXd955p8XlLcej7rzzzjuSoilAhfj/5NMH/vOf/0iSunTpYmNr1qyx+LTTTrM4pJh89atfjX3dUO21f//+NuZTKOL43lpnn322xfXr16/weflo48aNkqTTTz/dxl577bWs7UL6niQNGzbM4tWrV1vsq9oF5VU0q2hMyqSI+Md9T8RiSW31x884/nPrU/PCZ7hTp042NmjQIItDStuKFStszFdwJLX1M1u2bLHYp0NVxfjx4y0OlSwlafr06Ra3aNFCUulUcAzp2v52glx8H9+QyidlKrD27dvXxtq3b1/dKRaNcFz2KfK5UiB9OuW5555r8XXXXVfDsysuodLnBx98YGNPPfWUxcuXL7f4tttukyTdfvvtNrZ+/XqLfbXtXP70pz9JiqZrFqORI0daPHz4cEnln0vkqqTqb7kLtzL5lOTyXvc3v/lN1f8BNYgVSQAAAABAIgW3Ihl6Mr3wwguxj3/729+WJD322GM25vs0oWaEfmbTpk2zsddffz1223BD77vvvhv7+Icffigp+mujvwG/UCxatMjidu3aZT1+1llnWez/fbmKKjz33HOSpMmTJ1d6LpdeeqnFvs9WKIbiV4byXXhfhV86K+O///1vbU0nyz777GOxL4ZSLEaNGmVxWHH0K7vdunWz2N/IHzfmV4rjPuNPP/20xX5VpxSFvrK+gJM/xlRXyASRpK997WsW33HHHZIyffyk4ludXLhwocWhmNahhx5qY75AVJs2bSwO2U8+46mUiopUVyja4rMc/Iqt/8z36dNHUrT3Y1g5k6KrNHEaNGhgsc/KKTX+vdq7d+/YbS6//PKssfnz51vss0aC0O9Qiq5k/uMf/5BU/CuSYRVSyl3UpioFcKpT/CZtrEgCAAAAABLhQhIAAAAAkEhBpLbec889FofeZOXd/Pvb3/5WEums1RV6MPm+cA888IDFSdJ5QgqKL0rgl/W/8pWvSCqspfwgpJ9J0rXXXpv1eEhLlaSjjz7a4iQ9wr773e9mvdbWrVst/vKXv2xx6Gfk0+FWrVpl8cUXXyxJOuSQQ2ysa9eulZ5LXfjGN74hKZrO7tMiJ06cKClTJETKnYrnC/eEdHkp0+dw4MCBNhb6cUnxxU5OOeUUi316XLH4wQ9+YPGJJ54oKVoAJklhrFzb+gIGpZ7aGj7j4TNbkVBEyx9jN2zYUKW/G3rthf62UrS4TDE4+OCDLQ7pwj4F0KerP/nkkxZTAKp6evbsKUlasmSJje23334W+/dvHH/eEXcO4tNZ/ffFnnvumXyyJe7II4+s8HGfjulTW30P52Lmz2EDn27tH/dF6MI1TCGe75aHFUkAAAAAQCJcSAIAAAAAEsnb1FafluPT9HyfsSCk/knFtVxcl0Il1XHjxtlYXFpJo0aNbOz73/++xb4KYEh18Gkn/v+xvP6JhcD3Yps6darFIS3sqKOOsrEk6axeSPcJFYkrEnr2hfQJSerYsaPFIc3VP+7/DblSi+pCeN/4f8dhhx1mca9evSRVPRUyLiXZ833MSl1IZyrEPq/FbPTo0ZKiFUZ9Zea4Xsu5/O1vf7P41FNPtbgYKrj6FNX9999fUialV5K+/vWvW1zI30/5Ksl5mq/APXbs2Aq39ZVE/f8haobv33zGGWdYHG5PkqTWrVunOaU6c+utt1ocKriWV4m1ELsQJMGKJAAAAAAgES4kAQAAAACJ5G1q6znnnGPxv/71rwq39c1t8zE1r1D4dNPQrNpXO/PVG0NqpE/7KaV0t1Cdy1ey9aZMmSIpWgkwTT4d06fchtRX32zYp11ceumlKcyu+nbddVeLa6O656ZNmyz2x5+4qmy5GmMDVbF69eoKH/e3D4TKwbvvvruNTZgwweKXX35ZUjT1z1cr3rx5c9brjx8/3uLjjz/e4mKr4Bqq4z711FM25o8puY7hH330kcVxlRzDvpei1UoDX7XUV/YOx7hiSCWuLH8Ocv/990uSLrz
wQhuLq9R69dVXW0w6a+34+OOPJUXTWb3nn3/eYl9FvpgNGTLE4qVLl0qSxowZY2NxxwJv9uzZseOdO3eWFH2vL1682OKDDjoo8VxrGyuSAAAAAIBE8mpF0l/B+19I437xnzNnjsXF2LutLkyePNni8Ktz6LElSa1atUp9Tvlq0qRJkjJ9GyWpWbNmFofeh/ngW9/6lsVt27aVJL3xxhs2tn79+tTnlK82btwoSfrxj39sY/5Y438lDMedESNGpDQ7FDvfl/aGG26ocFvf49evRAY+KyHEvmiOf48PHTrU4rlz52a9VlzhnWJZKQvnGv5Xf7+ffM/qBQsWSIoWAPTfm7lWkZMIq8Dt2rWzMX+sKeSelr4P8mOPPWaxLwK3du3arOfFnQv6feLjV155xeJvfvObVZ8srIe136e+v3L79u1Tn1M+CecF/vzAF9vxhc/CNn710m8bHvdj3bp1s9gfi/OluCgrkgAAAACARLiQBAAAAAAkklepra+++qrF06dPt9gvF3fv3l2SdPjhh1f6df3N8D5dIhRK8Tdoh6IFUvGk7lQk9IuUpJ/+9KcWL1y4UJLUsmXLCp/vU3n8axV7Gqx/rwZXXHGFxfl6w/mZZ54pKZpChIzXX39dkvT3v/8957ahwIDvoYXKC2lqPl2t1AsXbdu2zeKHHnqowm1zFV3IVSTG96X17/eQlh8KSEjRwjuhOFchp1Z6PsUs8L2pvfB95x+/4447LA6Fc3zRnCSWLVtmcSiI5vvc+tt/7r777ir9jXzg0yJHjhwZu0047/PHBH8u2LhxY0nSmjVrbMxv63s433fffZKi53e+OGC41SPc+oFo/87TTz9dUrSX8wUXXJD6nPJV3HeZ/6z6Y0zc+9pve95550W2kzJ9gqXoubW/FaIusSIJAAAAAEiEC0kAAAAAQCJ5kdoaqqJdf/31ObcNKSU777xz7ONhufjJJ5+0MZ8aElLXyuOrWbZp0ybnfAqRXw73++bYY4+1uFGjRpKiKU8zZsywOKS+/vWvf7Wx1q1bW+z7FBZjinBcb9OVK1fWwUySadq0adYYFe0yxo4dW+ltSWn9zJYtWyz2x1df7Tbwx4W4nnChX5kUTVkLqeK+f2gx8hUs4/h01iOOOKLG/q5PU/35z38uqfyeccXAV2j1aZbBTjtlTo18CmnY582bN7exmvx+8+ccIV61apWN+d51hZza6qu2er5S5QknnCBJOuaYY2zMHz/Cv9/3U/Wv66vthgq45fW6JqX1M/42sHAckKR169ZJih6fGjZsmNq88l24VcifP/j3sk9dDeMzZ860Mf9d6KvsB+VVeA3VYP2tSnVRyZUVSQAAAABAInmxIhlWcqZNm5ZzW9/7KvA9Jy+55BJJ0rhx46o0l3vuucfiUFSg2PgVydAPUYr+8nnwwQdLkr7zne/YWCjUImV6nP3xj3+0Md97K/yCJUl77713Dcw6v8StjPTp06cOZlJ9jz/+uMW+GEEp8r8SxvGfgVIUMj6eeeYZG/NFF5YvX16t13/77bctbtKkicWhT5l/f/r+waG4yV577WVjcSue+e69996r8HF/PM5VTKeqwnHMZ6v4/9eJEydKks4+++xa+fu1xX/vheIrkrR9+3ZJ0T7Auf4f0hCKnVx33XU25udYyHz2k/+cHnnkkRb784mKnHTSSRY/+uijFodiPFL5K5GI8lmB/hx68ODBkqR99tkn9TkVgrAKmKuAjpRZPUyycvjCCy9Y3KVLF4vDSqUvlMmKJAAAAAAg73EhCQAAAABIJC9SW3MVGPDCTdOdOnWyMd/TL8RJ0pp8T6jhw4dX+nmFyt+s63uJHXbYYRb36tVLktSuXbsKX8unsB533HEWh2I9xSqkg/p9GdLv8tmGDRuyxnz/z1J01llnWex75wW+qE6p999ctGiRJOnEE09M9e+G4lZxRa4839Pvsssuq9U51YX+/fvX+t8IKbPlFbT74IMPan0OteG1116z2L9PQlGbefP
mpT6nz/PFTgYOHCgpmk44a9as1OdUG/zxo6rHkrCvXnzxxdjHfRGeBg0aVOlvlIqwL2fPnm1j/nzutttukyTtsssu6U6swPjPZ64COkmE3r6S1KNHD4uffvppSdGCYf52s7SwIgkAAAAASIQLSQAAAABAInmR2hoqDi1btszGfv3rX8duG7bx21aXT+/cb7/9aux10+B7/YS+SSElRpIOOOCArOf4Hlm+ylkSoY/k+eefb2O+kmMx9o70CqkipP+s3HrrrVmP/+QnP0lzOnln/PjxFsf9vz7yyCMWh2rGpcT3YzvnnHOyHq9fv77Fo0aNsjj0gTv55JNt7JVXXsl6vq8U6Pvo/eEPf7A4VNGcO3du7LyCq666ymKfzlZoFUbr0q9+9StJNfsdW5dCOrTvR+gNGzZMUt1V9gzVWaXod3foyfrEE0/Y2L777pvexPLcW2+9JSm6/7xSvw0hF59GHaqK+jRhX4WelNbK6dixY628rv8u8+f0oUpsSHGtK6xIAgAAAAAS4UISAAAAAJBIXqS2hob1vpKaV16aa1X4aqIh9aGQq/v98Ic/tLhnz56SMhW2/JiUqfjXqlUrG/NNvXOZMWOGxaGq64UXXmhjXbt2rfRrFbrddttNUjQ9ZMqUKRb37ds39TlJmfn8/ve/tzH//g6pm71797axfv36pTS7/BFXvbY8vmprKfL7ylf1C0aOHGmxT319+eWXJWXS4D+vQ4cOkqKprV7cZ2jlypUWL1682OJQLc/fpnDIIYfEvm6hadGihcWdO3eusdddv369xXfffbfFN998syRpx44dNuYbbdfkHNIwYcIESdLGjRttrCqN72taSF31n5nt27dbHNK4mzZtmu7E8tj//vc/i88999ysx0PFYUnq1q1bKnMqVKtXr7b44YcflhR9r914442pzwm5+dtvQuy7B9QFViQBAAAAAInkxYpk4G/o9auT/iZ5v5JSEV8AYtCgQRYPGDDAYt+bpVD5FcX3339fkjR//nwbC/0OJWnq1KmSokVx/K8bW7dutTj8Mupv8u3evbvFY8eOlSSddtpp1fsHFKgrrrhCkjRixAgbCzesS9Ly5cslRVcDq1uAyP+COGfOHIt936Bw03UoTiJFVxzD6lGpFm0I73HfOzaO/zW7bdu2tTqnQte4cWOL/bEnFNn59NNPbWz33Xe3+P7770/8t/wv5j4u5myITZs2WexXEX1/wTh+Be7Pf/6zpOjK4xtvvGHx5s2bK3wt/z3hV/MKQVht8fx5RHn9MmuD/7648847JUmtW7e2Mb/in+a80hLec1U9pvrzwpDx4IXvZSl6DohscZl+kyZNstj34UTd8lk/Y8aMsThkivjv2LrAiiQAAAAAIBEuJAEAAAAAieRVaqvn01yTpOHdddddkqJphqWW4uBTj5KkIfmiGuGm9kLrq5mWiy66SJL0zjvv2NjEiRMtHjp0qCTpgQcesDFf/MMXSYrji4qMGzdOkrRixQob82nIe+21V9a8QmElKZp6WOr9oP7yl79Ikl599dUKtxs9erTFPr27FPkUu1B46MMPP7SxM844I/Z527ZtyxrzqWfFUgyntq1du9ZiX3jH9wOO44vl+ONFVYRei1K0oEkhCOlfvkdpXKGWmrRkyRKLfZr8qlWrLL7gggskZb4rpOJMZ/XFhB588EFJ0UJavk+vF1K6fVrfLbfckrWd76t30003VW+yRc7fHnPvvfda3KxZM0lSy5YtU58T4vnPxfDhwy32hXVCSivFdgAAAAAABYULSQAAAABAInmb2pqE75V48cUXS6p+hcxS1LBhw9gY2ULFxF/+8pc25qtIhhTrt99+28Z8PHny5KzX9GnEvuJily5dJEVTCM8//3yLfSVMZPNVL0PVP98Xz1c8CxWefQphqfN9NENK9vjx420sLoVVyrwvfTrrNddcUxtTLBm+b21t85Vwr732Wot9BddCEI61ofevVLO9YT/55BOLQyVW34PvqKOOsth
XUT/iiCNqbA757K233rI4vHfCd5okLVq0yOJp06ZZHFL7FixYkPV8SWrfvr2k+N62iOer3m7ZssXiIUOGSMpdCRq1L6Sp+vd6eecrzZs3lyQ9++yzKc0uHiuSAAAAAIBECmJF8uijj7a4rvulAF6jRo0s9oUAfP/Iytpjjz0sDn08pWgxHSQXilpImf6b/tc+f6N6qawSVNV9990nKboK4FfaO3ToYHHoE0lRncrzRRV8jKrz5w81xRdDu/zyyy0OKwOjRo2ysYEDB9b43y90gwcPtjjJCndYhZQy/bALrfhTXQpFFKVMgR2J92hN8T3ry1tRvPLKKyVJt99+e+y2IY4rqvP58XC8OfDAA6s99+pgRRIAAAAAkAgXkgAAAACARAoitRUoBL6vW5Lep6hd69atq/S2vXv3rsWZFL5QxOzNN9+s45kAdcens/rUynnz5knKFMFAtFBT6O/97rvv5nxenz59JEn9+vWzsR49elhMQcXK8d9/4dYESercubPFTZo0SXVOxWrMmDEWl5faOnbs2KyxuG1POOEEG/O9UkNqrFT3Ka0BK5IAAAAAgES4kAQAAAAAJEJqK4CSFnrPSlQABJBb6BcpRatfIptP0Vu6dGndTaRE+e+0Qw891OLjjjuuLqZT1GbNmmWxT1ft1KmTxaHqanmVWGfOnClJatu2rY01aNCg5idbg1iRBAAAAAAkwoUkAAAAACARUlsBFLW+fftaPH36dEnSmWeeaWO+eTgA5EI6KwpF/fr1LZ47d24dzqT4dezYMXZ8x44dKc8kXaxIAgAAAAASYUUSQFEbMGBAbAwAAICqY0USAAAAAJAIF5IAAAAAgETqlZWVVX7jevVWS1pWe9MpKs3Lysr2rsoT2c+Jsa/Tw75OD/s6PezrdLCf08O+Tg/7Oj3s63RUej8nupAEAAAAAIDUVgAAAABAIlxIAgAAAAAS4UISAAAAAJAIF5IAAAAAgES4kAQAAAAAJMKFJAAAAAAgES4kAQAAAACJcCEJAAAAAEiEC0kAAAAAQCL/B0iRuXJnceFTAAAAAElFTkSuQmCC\n", 51 | "text/plain": [ 52 | "" 53 | ] 54 | }, 55 | "metadata": {}, 56 | "output_type": "display_data" 57 | } 58 | ], 59 | "source": [ 60 | "mnist = mx.test_utils.get_mnist()\n", 61 | "#print(mnist['train_data'][0].shape)\n", 62 | "#plt.imshow(mnist['train_data'][0][0],cmap='Greys')\n", 63 | "\n", 64 | "n_samples = 10\n", 65 | "idx = np.random.choice(len(mnist['train_data']), n_samples)\n", 66 | "_, axarr = plt.subplots(1, n_samples, figsize=(16,4))\n", 67 | "for i,j in enumerate(idx):\n", 68 | " axarr[i].imshow(mnist['train_data'][j][0], cmap='Greys')\n", 69 | " #axarr[i].axis('off')\n", 70 | " axarr[i].get_xaxis().set_ticks([])\n", 71 | " axarr[i].get_yaxis().set_ticks([])\n", 72 | "plt.show()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "#train_data = np.reshape(mnist['train_data'],(-1,28*28))\n", 82 | "#test_data = 
np.reshape(mnist['test_data'],(-1,28*28))\n", 83 | "train_data = mnist['train_data']\n", 84 | "test_data = mnist['test_data']" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "data": { 94 | "text/plain": [ 95 | "(10000, 1, 28, 28)" 96 | ] 97 | }, 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "mnist['test_data'].shape" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "batch_size = 100\n", 114 | "n_batches = train_data.shape[0] // batch_size # integer number of training batches\n", 115 | "train_iter = mx.io.NDArrayIter(data={'data': train_data}, label={'label': mnist['train_label']}, batch_size = batch_size)\n", 116 | "test_iter = mx.io.NDArrayIter(data={'data': test_data}, label={'label': mnist['test_label']}, batch_size = batch_size)\n", 117 | "#train_iter = mx.io.NDArrayIter(data={'data': train_data}, batch_size = batch_size)\n", 118 | "#test_iter = mx.io.NDArrayIter(data={'data': test_data}, batch_size = batch_size)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "# Define model" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "class Reshape(gluon.HybridBlock):\n", 135 | " def __init__(self, target_shape, **kwargs):\n", 136 | " super().__init__(**kwargs)\n", 137 | " self.target_shape = target_shape\n", 138 | "\n", 139 | " def hybrid_forward(self, F, x):\n", 140 | " #print(x.shape)\n", 141 | " return x.reshape((0, *self.target_shape)) # setting the first axis to 0 to copy over the original shape, i.e.
batch_size\n", 142 | "\n", 143 | " def __repr__(self):\n", 144 | " return self.__class__.__name__\n", 145 | "\n", 146 | "\n", 147 | "class VAE(gluon.HybridBlock):\n", 148 | " def __init__(self, n_ch=np.array([4,8,16])*4, ks=[4,4,4], stride=[2,2,2], pad=[3,1,1], \n", 149 | " n_latent=3, n_layers=1, batch_size=100, act_type='relu', use_bias=True, \n", 150 | " do_batch_norm = False, **kwargs):\n", 151 | " self.soft_zero = 1e-10\n", 152 | " self.n_latent = n_latent\n", 153 | " self.batch_size = batch_size\n", 154 | " self.output = None\n", 155 | " self.mu = None\n", 156 | " # note to self: requiring batch_size in the model definition is sad, not sure how to deal with this otherwise though\n", 157 | " super().__init__(**kwargs)\n", 158 | "\n", 159 | " img_size =[28]\n", 160 | " with self.name_scope():\n", 161 | " self.encoder = nn.HybridSequential(prefix='encoder')\n", 162 | " for i in range(n_layers):\n", 163 | " self.encoder.add(nn.Conv2D(channels=n_ch[i], kernel_size=ks[i], strides=stride[i], padding=pad[i], \n", 164 | " activation=act_type, use_bias=use_bias))\n", 165 | " old_size = img_size[-1] + 2*pad[i]\n", 166 | " new_size = (old_size - (ks[i] - stride[i]))// stride[i] # conv output size, i.e. (old_size - ks)//stride + 1\n", 167 | " #print(new_size)\n", 168 | " img_size.append(new_size)\n", 169 | " if i < n_layers-1 and do_batch_norm:\n", 170 | " self.encoder.add(nn.BatchNorm())\n", 171 | "\n", 172 | " self.encoder.add(nn.Flatten())\n", 173 | " #self.encoder.add(nn.Dense(100, activation=act_type)) \n", 174 | " self.encoder.add(nn.Dense(n_latent*2, activation=None)) \n", 175 | " \n", 176 | " \n", 177 | " self.decoder = nn.HybridSequential(prefix='decoder')\n", 178 | " #self.decoder.add(nn.Dense(100, activation=act_type)) \n", 179 | " self.decoder.add(nn.Dense(img_size[-1]*img_size[-1]*n_ch[-1], activation=act_type))\n", 180 | " self.decoder.add(Reshape((n_ch[-1], img_size[-1], img_size[-1])))\n", 181 | " for i in range(n_layers-1, -1, -1):\n", 182 | "\n", 183 | " if i == 0:\n", 184 | " act_type = 'sigmoid'\n", 185 | "
ch = 1\n", 186 | " else:\n", 187 | " ch = n_ch[i-1]\n", 188 | " \n", 189 | " self.decoder.add(nn.Conv2DTranspose(channels=ch, kernel_size=ks[i], strides=stride[i], \n", 190 | " padding=pad[i], activation=act_type, use_bias=use_bias))\n", 191 | " if i >0 and do_batch_norm:\n", 192 | " self.decoder.add(nn.BatchNorm())\n", 193 | "\n", 194 | " \n", 195 | " def hybrid_forward(self, F, x):\n", 196 | " h = self.encoder(x)\n", 197 | " #print(h)\n", 198 | " mu_lv = F.split(h, axis=1, num_outputs=2)\n", 199 | " mu = mu_lv[0]\n", 200 | " lv = mu_lv[1]\n", 201 | " self.mu = mu\n", 202 | " #eps = F.random_normal(loc=0, scale=1, shape=mu.shape, ctx=model_ctx) \n", 203 | " # this would work fine only for nd (i.e. non-hybridized block)\n", 204 | " eps = F.random_normal(loc=0, scale=1, shape=(self.batch_size, self.n_latent), ctx=model_ctx)\n", 205 | " z = mu + F.exp(0.5*lv)*eps # reparameterization trick: lv is the log-variance, so exp(0.5*lv) is the std\n", 206 | " y = self.decoder(z)\n", 207 | "\n", 208 | " y = y.reshape((0,-1))\n", 209 | " x = x.reshape((0,-1))\n", 210 | " self.output = y\n", 211 | " \n", 212 | " KL = 0.5*F.sum(1+lv-mu*mu-F.exp(lv),axis=1) # negative KL divergence between q(z|x) and the N(0, I) prior\n", 213 | " logloss = F.sum(x*F.log(y+self.soft_zero)+ (1-x)*F.log(1-y+self.soft_zero), axis=1) # Bernoulli log-likelihood of x\n", 214 | " loss = -logloss-KL # negative ELBO, minimized during training\n", 215 | " \n", 216 | " return loss" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 7, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "n_latent=100\n", 226 | "n_layers=3 # number of conv layers in the encoder and of transposed-conv layers in the decoder\n", 227 | "model_prefix = 'vaecnn_gluon.params'\n", 228 | "\n", 229 | "net = VAE(n_latent=n_latent, n_layers=n_layers, batch_size=batch_size)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 8, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "x = nd.random.normal(shape=(100,1,28,28), ctx=model_ctx)\n", 239 | "net.collect_params().initialize(mx.init.Xavier(), ctx=model_ctx)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 9, 245 |
"metadata": {}, 246 | "outputs": [ 247 | { 248 | "data": { 249 | "text/plain": [ 250 | "(100, 1, 28, 28)" 251 | ] 252 | }, 253 | "execution_count": 9, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "#net.encoder(x).shape\n", 260 | "net.decoder(net.encoder(x)).shape" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 10, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "data": { 270 | "text/plain": [ 271 | "\n", 272 | "[ 546.06201172 545.53796387 544.4074707 544.19537354 545.38684082\n", 273 | " 544.65002441 544.71551514 545.89727783 544.57885742 543.54858398\n", 274 | " 544.64135742 543.77264404 545.13153076 544.36578369 545.32037354\n", 275 | " 543.809021 545.52612305 545.95800781 544.73297119 544.39874268\n", 276 | " 545.74047852 543.75219727 544.55889893 545.34753418 543.24841309\n", 277 | " 545.53546143 544.65734863 544.09979248 545.3168335 544.69390869\n", 278 | " 546.03131104 545.4463501 545.23590088 546.28057861 545.41009521\n", 279 | " 544.92681885 544.93328857 544.43426514 544.87414551 545.15942383\n", 280 | " 545.28277588 544.64105225 544.15570068 545.71075439 545.38598633\n", 281 | " 545.50579834 544.54016113 543.35552979 545.65466309 545.78479004\n", 282 | " 545.89532471 544.95196533 545.56799316 543.90716553 544.16644287\n", 283 | " 544.49456787 545.82507324 545.60534668 545.86523438 543.34423828\n", 284 | " 545.19366455 544.73553467 543.7041626 543.7723999 544.83898926\n", 285 | " 545.00488281 544.32824707 546.16003418 544.12530518 545.0111084\n", 286 | " 545.12792969 545.07617188 544.95513916 544.73693848 543.88604736\n", 287 | " 545.25732422 546.05053711 547.10754395 544.11456299 544.38763428\n", 288 | " 543.24206543 544.23150635 544.25482178 543.01385498 543.39642334\n", 289 | " 544.87512207 546.18731689 545.75134277 545.22277832 544.31976318\n", 290 | " 546.71716309 545.74285889 545.81549072 546.0269165 545.50567627\n", 291 | " 545.02563477 545.15045166 
545.81072998 546.50354004 545.23431396]\n", 292 | "" 293 | ] 294 | }, 295 | "execution_count": 10, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | } 299 | ], 300 | "source": [ 301 | "net = VAE(n_latent=n_latent, n_layers=n_layers, batch_size=batch_size)\n", 302 | "net.collect_params().initialize(mx.init.Xavier(), ctx=model_ctx)\n", 303 | "net(x)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "# Model training" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 11, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "net = VAE(n_latent=n_latent, n_layers=n_layers, batch_size=batch_size)\n", 320 | "net.collect_params().initialize(mx.init.Xavier(), ctx=model_ctx)\n", 321 | "net.hybridize()\n", 322 | "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': .0005})" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 12, 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "data": { 332 | "application/vnd.jupyter.widget-view+json": { 333 | "model_id": "f61ebabcdb044432ac3bfe6e207d99c7", 334 | "version_major": 2, 335 | "version_minor": 0 336 | }, 337 | "text/html": [ 338 | "

\n" 351 | ], 352 | "text/plain": [ 353 | "HBox(children=(IntProgress(value=0, description='epochs: ', max=50), HTML(value='')))" 354 | ] 355 | }, 356 | "metadata": {}, 357 | "output_type": "display_data" 358 | }, 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "Epoch 0, Training loss 199.280682322, Validation loss 146.589102173\n", 364 | "Epoch 5, Training loss 108.088542442, Validation loss 106.836221542\n", 365 | "Epoch 10, Training loss 103.571974131, Validation loss 102.817866974\n", 366 | "Epoch 15, Training loss 101.714880193, Validation loss 101.300068817\n", 367 | "Epoch 20, Training loss 100.513956617, Validation loss 100.141811981\n", 368 | "Epoch 25, Training loss 99.6662532298, Validation loss 99.4761208344\n", 369 | "Epoch 30, Training loss 99.0791517258, Validation loss 99.0510205841\n", 370 | "Epoch 35, Training loss 98.6449058151, Validation loss 98.5340620422\n", 371 | "Epoch 40, Training loss 98.2070299784, Validation loss 98.3140554047\n", 372 | "Epoch 45, Training loss 97.8333548228, Validation loss 97.715945282\n", 373 | "\n", 374 | "Time elapsed: 248.27s\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "n_epoch = 50\n", 380 | "print_period = n_epoch // 10\n", 381 | "start = time.time()\n", 382 | "\n", 383 | "training_loss = []\n", 384 | "validation_loss = []\n", 385 | "for epoch in tqdm_notebook(range(n_epoch), desc='epochs'):\n", 386 | " epoch_loss = 0\n", 387 | " epoch_val_loss = 0\n", 388 | " \n", 389 | " train_iter.reset()\n", 390 | " test_iter.reset()\n", 391 | " \n", 392 | " n_batch_train = 0\n", 393 | " for batch in train_iter:\n", 394 | " n_batch_train +=1\n", 395 | " data = batch.data[0].as_in_context(model_ctx)\n", 396 | " with autograd.record():\n", 397 | " loss = net(data)\n", 398 | " loss.backward()\n", 399 | " trainer.step(data.shape[0])\n", 400 | " epoch_loss += nd.mean(loss).asscalar()\n", 401 | " \n", 402 | " n_batch_val = 0\n", 403 | " for batch in test_iter:\n", 404 | " n_batch_val 
+=1\n", 405 | " data = batch.data[0].as_in_context(model_ctx)\n", 406 | " loss = net(data)\n", 407 | " epoch_val_loss += nd.mean(loss).asscalar()\n", 408 | " \n", 409 | " epoch_loss /= n_batch_train\n", 410 | " epoch_val_loss /= n_batch_val\n", 411 | " \n", 412 | " training_loss.append(epoch_loss)\n", 413 | " validation_loss.append(epoch_val_loss)\n", 414 | " \n", 415 | " if epoch % max(print_period,1) == 0:\n", 416 | " tqdm.write('Epoch %d, Training loss %s, Validation loss %s' % (epoch, epoch_loss, epoch_val_loss))\n", 417 | " \n", 418 | "end = time.time()\n", 419 | "print('Time elapsed: {:.2f}s'.format(end - start))" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 13, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "net.save_params(model_prefix)" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 14, 434 | "metadata": {}, 435 | "outputs": [ 436 | { 437 | "data": { 438 | "text/plain": [ 439 | "" 440 | ] 441 | }, 442 | "execution_count": 14, 443 | "metadata": {}, 444 | "output_type": "execute_result" 445 | }, 446 | { 447 | "data": { 448 | "image/png": "\n", 449 | "text/plain": [ 450 | "" 451 | ] 452 | }, 453 | "metadata": {}, 454 | "output_type": "display_data" 455 | } 456 | ], 457 | "source": [ 458 | "batch_x = np.linspace(1, n_epoch, len(training_loss))\n", 459 | "plt.plot(batch_x, -1*np.array(training_loss))\n", 460 | "plt.plot(batch_x, -1*np.array(validation_loss))\n", 461 | "plt.legend(['train', 'valid'])" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "# Load model" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 15, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "net2 = VAE(n_latent=n_latent, n_layers=n_layers, batch_size=batch_size)\n", 478 | "net2.load_params(model_prefix, ctx=model_ctx)" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": 
[ 485 | "# Visualizing reconstruction quality" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 16, 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "test_iter.reset()\n", 495 | "test_batch = test_iter.next()\n", 496 | "net2(test_batch.data[0].as_in_context(model_ctx))\n", 497 | "result = net2.output.asnumpy()\n", 498 | "original = test_batch.data[0].asnumpy()" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 17, 504 | "metadata": {}, 505 | "outputs": [ 506 | { 507 | "data": { 508 | "image/png": "\n", 509 | "text/plain": [ 510 | "" 511 | ] 512 | }, 513 | "metadata": {}, 514 | "output_type": "display_data" 515 | } 516 | ], 517 | "source": [ 518 | "n_samples = 10\n", 519 | "idx = np.random.choice(batch_size, n_samples)\n", 520 | "_, axarr = plt.subplots(2, n_samples, figsize=(16,4))\n", 521 | "for i,j in enumerate(idx):\n", 522 | " axarr[0,i].imshow(original[j].reshape((28,28)), cmap='Greys')\n", 523 | " if i==0:\n", 524 | " axarr[0,i].set_title('original')\n", 525 | " #axarr[0,i].axis('off')\n", 526 | " axarr[0,i].get_xaxis().set_ticks([])\n", 527 | " axarr[0,i].get_yaxis().set_ticks([])\n", 528 | "\n", 529 | " axarr[1,i].imshow(result[j].reshape((28,28)), cmap='Greys')\n", 530 | " if i==0:\n", 531 | " axarr[1,i].set_title('reconstruction')\n", 532 | " #axarr[1,i].axis('off')\n", 533 | " axarr[1,i].get_xaxis().set_ticks([])\n", 534 | " axarr[1,i].get_yaxis().set_ticks([])\n", 535 | "plt.show()" 536 | ] 537 | }, 538 | { 539 | "cell_type": "markdown", 540 | "metadata": {}, 541 | "source": [ 542 | "# Visualizing latent space (when it is 2D)" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 18, 548 | "metadata": {}, 549 | "outputs": [], 550 | "source": [ 551 | "n_batches = 10\n", 552 | "counter = 0\n", 553 | "results = []\n", 554 | "labels = []\n", 555 | "for batch in test_iter:\n", 556 | " net2(batch.data[0].as_in_context(model_ctx))\n", 557 | " 
results.append(net2.mu.asnumpy())\n", 558 | " labels.append(batch.label[0].asnumpy())\n", 559 | " counter +=1\n", 560 | " if counter >= n_batches:\n", 561 | " break" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 19, 567 | "metadata": {}, 568 | "outputs": [], 569 | "source": [ 570 | "result= np.vstack(results)\n", 571 | "labels = np.hstack(labels)" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": 20, 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 | "if result.shape[1]==2:\n", 581 | " from scipy.special import ndtri\n", 582 | " from scipy.stats import norm\n", 583 | "\n", 584 | " fig, axarr = plt.subplots(1,2, figsize=(10,4))\n", 585 | " im=axarr[0].scatter(result[:, 0], result[:, 1], c=labels, alpha=0.6, cmap='Paired')\n", 586 | " axarr[0].set_title('scatter plot of $\\mu$')\n", 587 | " axarr[0].axis('equal')\n", 588 | " fig.colorbar(im, ax=axarr[0])\n", 589 | "\n", 590 | " im=axarr[1].scatter(norm.cdf(result[:, 0]), norm.cdf(result[:, 1]), c=labels, alpha=0.6, cmap='Paired')\n", 591 | " axarr[1].set_title('scatter plot of $\\mu$ on norm.cdf() transformed coordinates')\n", 592 | " axarr[1].axis('equal')\n", 593 | " fig.colorbar(im, ax=axarr[1])\n", 594 | " plt.tight_layout()\n", 595 | " if output_fig:\n", 596 | " plt.savefig('2d_latent_space_for_test_samples.png')" 597 | ] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": {}, 602 | "source": [ 603 | "# Sample latent space and generate images" 604 | ] 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "metadata": {}, 609 | "source": [ 610 | "## Random sampling" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 21, 616 | "metadata": {}, 617 | "outputs": [], 618 | "source": [ 619 | "n_samples = 10\n", 620 | "zsamples = nd.array(np.random.randn(n_samples*n_samples, n_latent))" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 22, 626 | "metadata": {}, 627 | "outputs": [], 
628 | "source": [ 629 | "images = net2.decoder(zsamples.as_in_context(model_ctx)).asnumpy()" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": 23, 635 | "metadata": {}, 636 | "outputs": [ 637 | { 638 | "data": { 639 | "image/png": "\n", 640 | "text/plain": [ 641 | "" 642 | ] 643 | }, 644 | "metadata": {}, 645 | "output_type": "display_data" 646 | } 647 | ], 648 | "source": [ 649 | "canvas = np.empty((28*n_samples, 28*n_samples))\n", 650 | "for i, img in enumerate(images):\n", 651 | " x = i // n_samples\n", 652 | " y = i % n_samples\n", 653 | " canvas[(n_samples-y-1)*28:(n_samples-y)*28, x*28:(x+1)*28] = img.reshape(28, 28)\n", 654 | "plt.figure(figsize=(4, 4)) \n", 655 | "plt.imshow(canvas, origin=\"upper\", cmap=\"Greys\")\n", 656 | "plt.axis('off')\n", 657 | "plt.tight_layout()\n", 658 | "if output_fig:\n", 659 | " plt.savefig('generated_samples_with_{}D_latent_space.png'.format(n_latent))" 660 | ] 661 | }, 662 | { 663 | "cell_type": "markdown", 664 | "metadata": {}, 665 | "source": [ 666 | "## Grid scan 2D latent space" 667 | ] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "execution_count": 24, 672 | "metadata": {}, 673 | "outputs": [], 674 | "source": [ 675 | "if n_latent==2: \n", 676 | " n_pts = 20\n", 677 | "\n", 678 | " idx = np.arange(0, n_pts)\n", 679 | "\n", 680 | " x = np.linspace(norm.cdf(-3), norm.cdf(3),n_pts)\n", 681 | " x = ndtri(x)\n", 682 | "\n", 683 | " x_grid = np.array(np.meshgrid(*([x] * n_latent)))\n", 684 | " id_grid = np.array(np.meshgrid(*([idx] * n_latent)))\n", 685 | "\n", 686 | " zsamples = nd.array(x_grid.reshape((n_latent, -1)).transpose())\n", 687 | " zsamples_id = id_grid.reshape((n_latent, -1)).transpose()\n", 688 | "\n", 689 | " images = net2.decoder(zsamples.as_in_context(model_ctx)).asnumpy()\n", 690 | "\n", 691 | " #plot\n", 692 | " canvas = np.empty((28*n_pts, 28*n_pts))\n", 693 | " for i, img in enumerate(images):\n", 694 | "
#plt.imshow(img.reshape(28,28))\n", 695 | " x, y = zsamples_id[i]\n", 696 | " canvas[(n_pts-y-1)*28:(n_pts-y)*28, x*28:(x+1)*28] = img.reshape(28, 28)\n", 697 | " plt.figure(figsize=(6, 6)) \n", 698 | " plt.imshow(canvas, origin=\"upper\", cmap=\"Greys\")\n", 699 | " plt.axis('off')\n", 700 | " plt.tight_layout()\n", 701 | " if output_fig:\n", 702 | " plt.savefig('2d_latent_space_scan_for_generation.png')" 703 | ] 704 | } 705 | ], 706 | "metadata": { 707 | "kernelspec": { 708 | "display_name": "Python 3", 709 | "language": "python", 710 | "name": "python3" 711 | }, 712 | "language_info": { 713 | "codemirror_mode": { 714 | "name": "ipython", 715 | "version": 3 716 | }, 717 | "file_extension": ".py", 718 | "mimetype": "text/x-python", 719 | "name": "python", 720 | "nbconvert_exporter": "python", 721 | "pygments_lexer": "ipython3", 722 | "version": "3.5.2" 723 | } 724 | }, 725 | "nbformat": 4, 726 | "nbformat_minor": 2 727 | } 728 | --------------------------------------------------------------------------------
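The grid-scan cell in `vaecnn-gluon.ipynb` spaces its 20 points evenly in cumulative-probability space (`np.linspace(norm.cdf(-3), norm.cdf(3), n_pts)`) and maps them back through the inverse normal CDF (`ndtri`). As a result, the points sit closer together near the origin, where the N(0, I) prior puts most of its mass, and spread out toward ±3. The following is a minimal sketch that reproduces that grid and the canvas layout without MXNet or SciPy, assuming Python 3.8+ so that `statistics.NormalDist.inv_cdf` can stand in for `scipy.special.ndtri`. It also builds the grid with `np.meshgrid(*([x] * n_latent))` rather than `np.matlib.repmat`. The helper names `latent_grid`, `grid_samples`, and `tile` are hypothetical, and the decoder output is replaced by placeholder images so the layout can be checked on its own.

```python
from statistics import NormalDist

import numpy as np

# Hypothetical helpers (not part of the notebook) that mirror the
# grid-scan cell with NumPy and the standard library instead of MXNet/SciPy.


def latent_grid(n_pts=20, z_max=3.0):
    """Return n_pts values spaced evenly in CDF space between Phi(-z_max)
    and Phi(z_max), mapped back through the inverse standard-normal CDF
    (NormalDist().inv_cdf stands in for scipy.special.ndtri)."""
    std = NormalDist()  # standard normal N(0, 1)
    p = np.linspace(std.cdf(-z_max), std.cdf(z_max), n_pts)
    return np.array([std.inv_cdf(float(q)) for q in p])


def grid_samples(x, n_latent=2):
    """Return every n_latent-tuple over the 1-D grid x (one latent vector
    per row) and the matching integer grid indices, using the same
    meshgrid -> reshape -> transpose layout as the notebook."""
    idx = np.arange(len(x))
    x_grid = np.array(np.meshgrid(*([x] * n_latent)))
    id_grid = np.array(np.meshgrid(*([idx] * n_latent)))
    z = x_grid.reshape((n_latent, -1)).T
    z_id = id_grid.reshape((n_latent, -1)).T
    return z, z_id


def tile(images, z_id, n_pts, size=28):
    """Paste each flattened size*size image into an (n_pts*size)^2 canvas.
    Grid column x runs left to right; grid row y is flipped so that larger
    y (a larger second latent coordinate) lands nearer the top."""
    canvas = np.zeros((size * n_pts, size * n_pts))
    for img, (x, y) in zip(images, z_id):
        canvas[(n_pts - y - 1) * size:(n_pts - y) * size,
               x * size:(x + 1) * size] = img.reshape(size, size)
    return canvas


# Usage with placeholder "decoded" images: image k is filled with the value k.
x = latent_grid(n_pts=20)
z, z_id = grid_samples(x, n_latent=2)
fake_images = np.repeat(np.arange(len(z), dtype=float), 28 * 28).reshape(len(z), 28 * 28)
canvas = tile(fake_images, z_id, n_pts=20)
print(canvas.shape)  # (560, 560)
```

In the notebook itself, the placeholder images correspond to `net2.decoder(nd.array(z).as_in_context(model_ctx)).asnumpy()`. Spacing the grid in CDF space rather than linearly in z keeps more of the 400 decoded digits in the high-density region of the prior.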