├── README.md
├── mmd_vae_pytorchver.ipynb
├── mmd_vae_pytorchver_norunlognoimg.ipynb
└── plots
    ├── generation.png
    ├── scatter-after-training.png
    ├── scatter-after-training_classic.png
    ├── scatter-before-training.png
    └── scatter-before-training_classic.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch Implementation of MMD Variational Autoencoder
-----------------------------------------------------------------------------------------------------------------------------
A PyTorch implementation of the paper [InfoVAE: Information Maximizing Variational Autoencoders](https://arxiv.org/abs/1706.02262).

The code is a port of the TensorFlow [implementation](https://github.com/ShengjiaZhao/MMD-Variational-Autoencoder) by [Shengjia Zhao](https://github.com/ShengjiaZhao).

The results shown below were generated with this PyTorch code.

Details and motivation are described in the [paper](https://arxiv.org/abs/1706.02262) and in Shengjia Zhao's [tutorial](http://szhao.me/2017/06/10/a-tutorial-on-mmd-variational-autoencoders.html).

Samples generated after 27 epochs of training on MNIST:

![mnist](plots/generation.png)

Scatter plots of the latent codes, colored by digit label (here z is 3-dimensional):

## Before Training
![mnist](plots/scatter-before-training_classic.png)

## After Training
![mnist](plots/scatter-after-training_classic.png)
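## The MMD Objective

Instead of the KL term of a classic VAE, the objective penalizes the maximum mean discrepancy (MMD) between encoded samples and draws from the N(0, I) prior. Below is a minimal sketch of the loss, mirroring the notebook's `compute_kernel`/`compute_mmd`; the `infovae_loss` helper and its inputs (`x`, encoder output `z`, reconstruction `x_hat`) are illustrative names, not part of the notebook:

```python
import torch

def compute_kernel(x, y):
    # Gaussian kernel evaluated on every pair of rows of x and y
    x_size, y_size, dim = x.shape[0], y.shape[0], x.shape[1]
    tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1)
    tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1)
    return torch.exp(-torch.mean((tiled_x - tiled_y) ** 2, dim=2) / dim)

def compute_mmd(x, y):
    # Biased sample estimate of MMD^2 between the distributions of x and y
    return (compute_kernel(x, x).mean()
            + compute_kernel(y, y).mean()
            - 2 * compute_kernel(x, y).mean())

def infovae_loss(x, z, x_hat):
    # Illustrative helper (not in the notebook): per-pixel reconstruction
    # error plus MMD between the encoded z and samples from the N(0, I) prior
    prior_z = torch.randn(z.shape[0], z.shape[1], device=z.device)
    return torch.mean((x_hat - x) ** 2) + compute_mmd(prior_z, z)
```

The notebook weights both terms equally; the InfoVAE paper additionally allows a tunable coefficient on the MMD term.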
--------------------------------------------------------------------------------
/mmd_vae_pytorchver_norunlognoimg.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "from torch.nn import functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import math, os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "z_dim = 3\n",
    "usecuda = True\n",
    "idgpu = 2\n",
    "epochs = 40\n",
    "kwargs = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('./data', train=True, download=True,\n",
    "                   transform=transforms.ToTensor()),\n",
    "    batch_size=200, shuffle=True, **kwargs)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=200, shuffle=True, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encoder and decoder use the DC-GAN architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ChannelsToLinear(nn.Linear):\n",
    "    \"\"\"Flatten a Variable to 2d and apply Linear layer\"\"\"\n",
    "    def forward(self, x):\n",
    "        b = x.size(0)\n",
    "        return super().forward(x.view(b, -1))\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Encoder, self).__init__()\n",
    "        n_filters = 64\n",
    "        # two stride-2 convolutions: 28x28 -> 14x14 -> 7x7\n",
    "        self.conv1 = nn.Conv2d(1, n_filters, 4, 2, 1)\n",
    "        self.conv2 = nn.Conv2d(n_filters, n_filters*2, 4, 2, 1)\n",
    "\n",
    "        self.toLinear1 = ChannelsToLinear(n_filters*2*7*7, 1024)\n",
    "        self.fc1 = nn.Linear(1024, z_dim)\n",
    "\n",
    "        self.lrelu = nn.LeakyReLU(negative_slope=0.1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        h1 = self.lrelu(self.conv1(x))\n",
    "        h2 = self.lrelu(self.conv2(h1))\n",
    "        h3 = self.lrelu(self.toLinear1(h2))\n",
    "        h4 = self.fc1(h3)\n",
    "\n",
    "        return h4\n",
    "\n",
    "encodermodel = Encoder()\n",
    "if usecuda:\n",
    "    encodermodel.cuda(idgpu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class LinearToChannels2d(nn.Linear):\n",
    "    \"\"\"Reshape 2d Variable to 4d after Linear layer\"\"\"\n",
    "    def __init__(self, m, n, w=1, h=None, **kw):\n",
    "        h = h or w\n",
    "        super().__init__(m, n*w*h, **kw)\n",
    "        self.w = w\n",
    "        self.h = h\n",
    "    def forward(self, x):\n",
    "        b = x.size(0)\n",
    "        return super().forward(x).view(b, -1, self.w, self.h)\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Decoder, self).__init__()\n",
    "        n_filters = 64\n",
    "\n",
    "        self.fc1 = nn.Linear(z_dim, 1024)\n",
    "        self.LineartoChannel = LinearToChannels2d(1024, n_filters*2, 7, 7)\n",
    "        # two stride-2 transposed convolutions: 7x7 -> 14x14 -> 28x28\n",
    "        self.conv1 = nn.ConvTranspose2d(n_filters*2, n_filters, 4, 2, 1)\n",
    "        self.conv2 = nn.ConvTranspose2d(n_filters, 1, 4, 2, 1)\n",
    "\n",
    "        self.relu = nn.ReLU()\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, z):\n",
    "        h1 = self.relu(self.fc1(z))\n",
    "        h2 = self.relu(self.LineartoChannel(h1))\n",
    "\n",
    "        h3 = self.relu(self.conv1(h2))\n",
    "        h4 = self.sigmoid(self.conv2(h3))\n",
    "\n",
    "        return h4\n",
    "\n",
    "decodermodel = Decoder()\n",
    "if usecuda:\n",
    "    decodermodel.cuda(idgpu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MMD Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def compute_kernel(x, y):\n",
    "    x_size = x.shape[0]\n",
    "    y_size = y.shape[0]\n",
    "    dim = x.shape[1]\n",
    "\n",
    "    tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1)\n",
    "    tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1)\n",
    "\n",
    "    # Gaussian kernel on all pairs; the bandwidth is tied to the latent dimension\n",
    "    return torch.exp(-torch.mean((tiled_x - tiled_y)**2, dim=2) / dim)\n",
    "\n",
    "\n",
    "def compute_mmd(x, y):\n",
    "    x_kernel = compute_kernel(x, x)\n",
    "    y_kernel = compute_kernel(y, y)\n",
    "    xy_kernel = compute_kernel(x, y)\n",
    "    return torch.mean(x_kernel) + torch.mean(y_kernel) - 2*torch.mean(xy_kernel)"
   ]
  },
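  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference: `compute_mmd` above is the standard biased sample estimate of the squared maximum mean discrepancy,\n",
    "\n",
    "$$\\mathrm{MMD}^2(p, q) = \\mathbb{E}_{x, x' \\sim p}[k(x, x')] + \\mathbb{E}_{y, y' \\sim q}[k(y, y')] - 2\\,\\mathbb{E}_{x \\sim p,\\, y \\sim q}[k(x, y)],$$\n",
    "\n",
    "with `compute_kernel` supplying the Gaussian kernel $k$. MMD is zero exactly when the two distributions coincide, which is what lets it replace the KL regularizer of the classic VAE."
   ]
  },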
"collapsed": true 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "# Convert a numpy array of shape [batch_size, height, width, 1] into a displayable array \n", 202 | "# of shape [height*sqrt(batch_size, width*sqrt(batch_size))] by tiling the images\n", 203 | "def convert_to_display(samples):\n", 204 | " cnt, height, width = int(math.floor(math.sqrt(samples.shape[0]))), samples.shape[1], samples.shape[2]\n", 205 | " samples = np.transpose(samples, axes=[1, 0, 2, 3])\n", 206 | " samples = np.reshape(samples, [height, cnt, cnt, width])\n", 207 | " samples = np.transpose(samples, axes=[1, 0, 2, 3])\n", 208 | " samples = np.reshape(samples, [height*cnt, width*cnt])\n", 209 | " return samples" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 10, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "['seaborn-talk',\n", 221 | " 'grayscale',\n", 222 | " 'classic',\n", 223 | " 'Solarize_Light2',\n", 224 | " 'seaborn',\n", 225 | " 'seaborn-colorblind',\n", 226 | " 'ggplot',\n", 227 | " 'fivethirtyeight',\n", 228 | " 'seaborn-bright',\n", 229 | " 'seaborn-ticks',\n", 230 | " 'seaborn-pastel',\n", 231 | " 'seaborn-dark',\n", 232 | " 'seaborn-dark-palette',\n", 233 | " 'seaborn-darkgrid',\n", 234 | " 'bmh',\n", 235 | " 'seaborn-white',\n", 236 | " 'seaborn-paper',\n", 237 | " 'fast',\n", 238 | " 'tableau-colorblind10',\n", 239 | " 'seaborn-poster',\n", 240 | " 'dark_background',\n", 241 | " 'seaborn-whitegrid',\n", 242 | " 'seaborn-deep',\n", 243 | " '_classic_test',\n", 244 | " 'seaborn-notebook',\n", 245 | " 'seaborn-muted']" 246 | ] 247 | }, 248 | "execution_count": 10, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "plt.style.available" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 13, 260 | "metadata": { 261 | "collapsed": true 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "from mpl_toolkits.mplot3d import Axes3D" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "## Scatter Plot before training the VAE" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 1, 278 | "metadata": { 279 | "scrolled": false 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "# If latent z is 2-dimensional we visualize it by plotting latent z of different digits in different colors\n", 284 | "if z_dim == 3:\n", 285 | " z_list, label_list = [], []\n", 286 | " test_batch_size = 500\n", 287 | " #for i in range(20):\n", 288 | " i = 1\n", 289 | " for batch_idx, (test_x, test_y) in enumerate(test_loader):\n", 290 | " if(i>20):\n", 291 | " break\n", 292 | " test_x= Variable(test_x)\n", 293 | " if(usecuda):\n", 294 | " test_x = test_x.cuda(idgpu)\n", 295 | " z = encodermodel(test_x) \n", 296 | " z_list.append(z.data.cpu())\n", 297 | " label_list.append(test_y)\n", 298 | " i = i+1\n", 299 | " z = np.concatenate(z_list, axis=0)\n", 300 | " label = np.concatenate(label_list)\n", 301 | " \n", 302 | " fig = plt.figure()\n", 303 | " ax = fig.add_subplot(111, projection='3d')\n", 304 | " \n", 305 | "# plt.style.use('_classic_test')\n", 306 | " ax.scatter(z[:, 0], z[:, 1],z[:, 2], c=label)\n", 307 | " plt.show()" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "## Training the VAE" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 2, 320 | "metadata": { 321 | "scrolled": true 322 | }, 323 | "outputs": [], 324 | 
"source": [ 325 | "optimizerencoder = optim.Adam(encodermodel.parameters(), lr=1e-3)\n", 326 | "optimizerdecoder = optim.Adam(decodermodel.parameters(), lr=1e-3)\n", 327 | "for i in range(epochs):\n", 328 | "\n", 329 | " for batch_idx, (train_x, _) in enumerate(train_loader): \n", 330 | " train_x= Variable(train_x)\n", 331 | " \n", 332 | " true_samples = torch.randn((len(train_x),z_dim))\n", 333 | " true_samples = Variable(true_samples)\n", 334 | " \n", 335 | " if(usecuda):\n", 336 | " train_x = train_x.cuda(idgpu)\n", 337 | " true_samples = true_samples.cuda(idgpu)\n", 338 | " \n", 339 | " optimizerencoder.zero_grad()\n", 340 | " optimizerdecoder.zero_grad()\n", 341 | " \n", 342 | " train_z = encodermodel(train_x)\n", 343 | " \n", 344 | " train_xr = decodermodel(train_z)\n", 345 | " \n", 346 | " loss_mmd = compute_mmd(true_samples, train_z)\n", 347 | " loss_nll = torch.mean((train_xr - train_x)**2)\n", 348 | " \n", 349 | " loss = loss_nll + loss_mmd\n", 350 | " \n", 351 | " loss.backward()\n", 352 | " \n", 353 | " \n", 354 | " optimizerencoder.step()\n", 355 | " optimizerdecoder.step()\n", 356 | " \n", 357 | " if(batch_idx%100 == 0):\n", 358 | " print(\"Epoch %d : Negative log likelihood is %f, mmd loss is %f\" % (i,loss_nll.data[0], loss_mmd.data[0]))\n", 359 | " \n", 360 | " \n", 361 | " # show images\n", 362 | " gen_z = Variable(torch.randn((100, z_dim)))\n", 363 | " if(usecuda):\n", 364 | " gen_z = gen_z.cuda(idgpu)\n", 365 | " samples = decodermodel(gen_z)\n", 366 | " samples =samples.view(100,28,28,1)\n", 367 | " plt.imshow(convert_to_display(samples.data), cmap='Greys_r')\n", 368 | " plt.show() \n", 369 | " " 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 26, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "# %matplotlib tk" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "## Scatter Plot after training the VAE" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 3, 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "# If latent z is 2-dimensional we visualize it by plotting latent z of different digits in different colors\n", 395 | "if z_dim == 3:\n", 396 | " z_list, label_list = [], []\n", 397 | " test_batch_size = 500\n", 398 | " #for i in range(20):\n", 399 | " i = 1\n", 400 | " for batch_idx, (test_x, test_y) in enumerate(test_loader):\n", 401 | " if(i>20):\n", 402 | " break\n", 403 | " test_x= Variable(test_x)\n", 404 | " if(usecuda):\n", 405 | " test_x = test_x.cuda(idgpu)\n", 406 | " z = encodermodel(test_x) \n", 407 | " z_list.append(z.data.cpu())\n", 408 | " label_list.append(test_y)\n", 409 | " i = i+1\n", 410 | " z = np.concatenate(z_list, axis=0)\n", 411 | " label = np.concatenate(label_list)\n", 412 | "# plt.scatter(z[:, 0], z[:, 1], c=label)\n", 413 | "# plt.show()\n", 414 | " fig = plt.figure()\n", 415 | " ax = fig.add_subplot(111, projection='3d')\n", 416 | " \n", 417 | "# plt.style.use('_classic_test')\n", 418 | " ax.scatter(z[:, 0], z[:, 1],z[:, 2], c=label)\n", 419 | " plt.show()" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": { 426 | "collapsed": true 427 | }, 428 | "outputs": [], 429 | "source": [] 430 | } 431 | ], 432 | "metadata": { 433 | "kernelspec": { 434 | "display_name": "Python 3", 435 | "language": "python", 436 | "name": "python3" 437 | }, 438 | "language_info": { 439 | "codemirror_mode": { 440 | "name": "ipython", 441 | "version": 3 442 | }, 443 | 
"file_extension": ".py", 444 | "mimetype": "text/x-python", 445 | "name": "python", 446 | "nbconvert_exporter": "python", 447 | "pygments_lexer": "ipython3", 448 | "version": "3.6.9" 449 | } 450 | }, 451 | "nbformat": 4, 452 | "nbformat_minor": 1 453 | } 454 | -------------------------------------------------------------------------------- /plots/generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratikm141/MMD-Variational-Autoencoder-Pytorch-InfoVAE/b12c82f24acf26334f83c7b2d1eabfb14f1615a6/plots/generation.png -------------------------------------------------------------------------------- /plots/scatter-after-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratikm141/MMD-Variational-Autoencoder-Pytorch-InfoVAE/b12c82f24acf26334f83c7b2d1eabfb14f1615a6/plots/scatter-after-training.png -------------------------------------------------------------------------------- /plots/scatter-after-training_classic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratikm141/MMD-Variational-Autoencoder-Pytorch-InfoVAE/b12c82f24acf26334f83c7b2d1eabfb14f1615a6/plots/scatter-after-training_classic.png -------------------------------------------------------------------------------- /plots/scatter-before-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratikm141/MMD-Variational-Autoencoder-Pytorch-InfoVAE/b12c82f24acf26334f83c7b2d1eabfb14f1615a6/plots/scatter-before-training.png -------------------------------------------------------------------------------- /plots/scatter-before-training_classic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pratikm141/MMD-Variational-Autoencoder-Pytorch-InfoVAE/b12c82f24acf26334f83c7b2d1eabfb14f1615a6/plots/scatter-before-training_classic.png --------------------------------------------------------------------------------