├── README.md └── VAEtoy2d.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # VAETutorial 2 | 3 | The accompanying Jupyter notebook for ["A Tutorial on VAEs: From Bayes' Rule to Lossless Compression"](https://arxiv.org/abs/2006.10273). 4 | -------------------------------------------------------------------------------- /VAEtoy2d.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "VAEtoy2d.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm", 10 | "authorship_tag": "ABX9TyNBcuyBiME1uZ4tlpCjCL8z", 11 | "include_colab_link": true 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "RORJ6vRrPVlL", 34 | "colab_type": "text" 34 | }, 36 | "source": [ 37 | "## Typical VAEs\n", 38 | "This notebook walks through the code that produces the results in Section 6 of the paper. First, let us see what our ground-truth datasets look like. 
" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "id": "fsDv1FUqPT88", 45 | "colab_type": "code", 46 | "colab": { 47 | "base_uri": "https://localhost:8080/", 48 | "height": 34 49 | }, 50 | "outputId": "fffaf030-5fc9-404c-f28a-dd100243d371" 51 | }, 52 | "source": [ 53 | "import numpy as np\n", 54 | "import matplotlib.pyplot as plt\n", 55 | "import matplotlib.colors as col\n", 56 | "\n", 57 | "def plot_gt(data):\n", 58 | " limit = 0.5\n", 59 | " step = 1/1024.0\n", 60 | " pixels = int(2*limit/step)\n", 61 | " grid = np.array([[a, b] for a in np.arange(-limit, limit, step) for b in np.arange(-limit, limit, step)])\n", 62 | "\n", 63 | " if data =='8gaussians':\n", 64 | " centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)),\n", 65 | " (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2),\n", 66 | " 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))]\n", 67 | " centers = [(limit * x, limit * y) for x, y in centers]\n", 68 | " color = np.zeros((pixels*pixels,3))\n", 69 | " for i,center in enumerate(centers):\n", 70 | " x = grid*1.414 - center\n", 71 | " prob = np.prod(1/(2*np.pi/256.0)**0.5 * np.exp(-x**2/(2/256.0)),-1)\n", 72 | " color[:,0] += i/8.0 * prob\n", 73 | " color[:,2] += prob \n", 74 | " elif data =='checkerboard':\n", 75 | " l=[0,2,1,3,0,2,1,3]\n", 76 | " color = np.zeros((pixels,pixels,3))\n", 77 | " for i in range(8):\n", 78 | " y=i//2*256\n", 79 | " x=l[i]*256\n", 80 | " color[x:x+256, y:y+256,0]=i/8.0\n", 81 | " color[x:x+256, y:y+256,2]=1\n", 82 | " elif data=='2spirals':\n", 83 | " grid = grid.reshape((pixels,pixels,2))\n", 84 | " color = np.zeros((pixels,pixels,3))\n", 85 | " for i in range(10000):\n", 86 | " n = (i/10000.0)**0.5 * 540 * 2 *np.pi / 360\n", 87 | " d = np.zeros((1,2))\n", 88 | " d[0,0] = -np.cos(n) * n/3.0/8\n", 89 | " d[0,1] = np.sin(n) * n /3.0 /8\n", 90 | "\n", 91 | " idx = int((d[0,0]+limit)/step)\n", 92 | " idy = int((d[0,1]+limit)/step)\n", 93 | " x = grid[idx-50:idx+50, 
idy-50:idy+50,:] - d\n", 94 | " cur_prob = np.prod(1/(2*np.pi*0.01/64)**0.5 * np.exp(-x**2/(2*0.01/64)),-1)\n", 95 | "\n", 96 | " cur_color = np.ones((100,100,3))\n", 97 | " cur_color[:,:,0] = (i/20000.0+0.5)*cur_prob\n", 98 | " cur_color[:,:,2] = cur_prob\n", 99 | " color[idx-50:idx+50, idy-50:idy+50] += cur_color\n", 100 | "\n", 101 | " #other spiral\n", 102 | " idx = int((-d[0,0]+limit)/step)\n", 103 | " idy = int((-d[0,1]+limit)/step)\n", 104 | " x = grid[idx-50:idx+50, idy-50:idy+50,:] + d\n", 105 | " cur_prob = np.prod(1/(2*np.pi*0.01/64)**0.5 * np.exp(-x**2/(2*0.01/64)),-1)\n", 106 | " \n", 107 | " cur_color = np.ones((100,100,3))\n", 108 | " cur_color[:,:,0] = (-i/20000.0+0.5)*cur_prob\n", 109 | " cur_color[:,:,2] = cur_prob\n", 110 | " color[idx-50:idx+50, idy-50:idy+50] += cur_color\n", 111 | "\n", 112 | " \n", 113 | " color = color.reshape((pixels,pixels,3))\n", 114 | " color[:,:,0]/=(color[:,:,2]+1e-12)\n", 115 | " color[:,:,1]=1\n", 116 | " prob = color[:,:,2].reshape((pixels,pixels))\n", 117 | " prob = prob / np.sum(prob) #normalize the data\n", 118 | " prob+=1e-20\n", 119 | " entropy = - prob * np.log(prob)/np.log(2)\n", 120 | " entropy = np.sum(entropy)\n", 121 | " max_prob = np.max(prob)\n", 122 | "\n", 123 | " color[:,:,2]/=np.max(color[:,:,2])\n", 124 | " color[:,:,1]=color[:,:,2]\n", 125 | " color = np.clip(color, 0, 1)\n", 126 | " color = col.hsv_to_rgb(color)\n", 127 | "\n", 128 | "\n", 129 | " fig = plt.figure(figsize=(18, 18))\n", 130 | "\n", 131 | " ax1 = fig.add_subplot(1,2,1)\n", 132 | " ax1.axis('off')\n", 133 | " ax1.imshow(prob, extent=(-limit, limit, -limit, limit))\n", 134 | "\n", 135 | " ax2 = fig.add_subplot(1,2,2)\n", 136 | " ax2.axis('off')\n", 137 | " ax2.imshow(color, extent=(-limit, limit, -limit, limit))\n", 138 | "\n", 139 | " fig.tight_layout()\n", 140 | "\n", 141 | " return entropy-20, max_prob, prob, color\n", 142 | "\n", 143 | "entropy8g, max_prob8g, prob8g, color8g = plot_gt('8gaussians')\n", 144 | "print('Entropy for 
8gaussians: {:f}'.format( entropy8g))\n", 145 | "#print('Max probability for 8gaussians: {:e}'.format(max_prob8g))\n", 146 | "\n", 147 | "entropyc, max_probc, probc, colorc =plot_gt('checkerboard')\n", 148 | "print('Entropy for Checkerboard: {:f}'.format( entropyc))\n", 149 | "#print('Max probability for Checkerboard: {:e}'.format(max_probc))\n", 150 | "\n", 151 | "entropy2s, max_prob2s, prob2s, color2s =plot_gt('2spirals')\n", 152 | "print('Entropy for 2spirals: {:f}'.format( entropy2s))\n", 153 | "#print('Max probability for 2spirals: {:e}'.format(max_prob2s))\n" 154 | ], 155 | "execution_count": null, 156 | "outputs": [ 157 | { 158 | "output_type": "stream", 159 | "text": [ 160 | "Entropy for 8gaussians: -1.916776\n" 161 | ], 162 | "name": "stdout" 163 | } 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "id": "C6B7eXoWMvz9", 170 | "colab_type": "text" 171 | }, 172 | "source": [ 173 | "We have plotted two figures for each dataset.\n", 174 | "On the left-hand side is the ground truth density function. On the right-hand side is a color map of the data that we will later use to visualize the latent space of a VAE. \n", 175 | "\n", 176 | "We have also printed out the approximate entropy *H* of the ground truth probability distribution over the pixels. Recall that for continuous data, *H + 2n* can be interpreted as the number of bits needed to describe a sample from a 2D distribution to *n*-bit accuracy. As a sanity check, a uniform distribution across the whole domain should have an entropy of 0 bits, and the checkerboard dataset, which is a uniform distribution over half the domain, would have an entropy of -1. The entropy is also the value of the optimal negative log-likelihood for a maximum likelihood model.\n", 177 | "\n", 178 | "Now let us set up the dataloader. We will call ```sample2d``` every iteration during training to sample a batch of continuous values from the ground truth dataset to train the VAE." 
179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "0lJYTZD-L1QK", 185 | "colab_type": "code", 186 | "colab": {} 187 | }, 188 | "source": [ 189 | "def sample2d(data, batch_size=200):\n", 190 | " #code largely taken from https://github.com/nicola-decao/BNAF/blob/master/data/generate2d.py\n", 191 | "\n", 192 | " rng = np.random.RandomState()\n", 193 | "\n", 194 | " if data == '8gaussians':\n", 195 | " scale = 4\n", 196 | " centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)),\n", 197 | " (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2),\n", 198 | " 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))]\n", 199 | " centers = [(scale * x, scale * y) for x, y in centers]\n", 200 | "\n", 201 | " dataset = []\n", 202 | " #dataset = np.zeros((batch_size, 2))\n", 203 | " for i in range(batch_size):\n", 204 | " point = rng.randn(2) * 0.5\n", 205 | " idx = rng.randint(8)\n", 206 | " center = centers[idx]\n", 207 | " point[0] += center[0]\n", 208 | " point[1] += center[1]\n", 209 | " dataset.append(point)\n", 210 | " #dataset[i]=point\n", 211 | " dataset = np.array(dataset, dtype='float32')\n", 212 | " dataset /= 1.414\n", 213 | " return dataset/8.0\n", 214 | "\n", 215 | " elif data == '2spirals':\n", 216 | " n = np.sqrt(np.random.rand(batch_size, 1)) * 540 * (2 * np.pi) / 360\n", 217 | " d1x = -np.cos(n) * n\n", 218 | " d1y = np.sin(n) * n \n", 219 | " x = np.hstack((d1x, d1y)) / 3 * (np.random.randint(0, 2, (batch_size,1)) * 2 -1)\n", 220 | " x += np.random.randn(*x.shape) * 0.1\n", 221 | " return x/8.0\n", 222 | "\n", 223 | " elif data == 'checkerboard':\n", 224 | " x1 = np.random.rand(batch_size) * 4 - 2\n", 225 | " x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2\n", 226 | " x2 = x2_ + (np.floor(x1) % 2)\n", 227 | " return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 / 8.0\n", 228 | "\n", 229 | " else:\n", 230 | " raise RuntimeError\n" 231 | ], 232 | "execution_count": null, 
233 | "outputs": [] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": { 238 | "id": "21LuJO5yMuH-", 239 | "colab_type": "text" 240 | }, 241 | "source": [ 242 | "Now that we have set up our data loader, we can construct a Typical VAE. The architectural backbone of our encoder and decoder will be a *DenseBlock*, which concatenates the output of each fully connected layer with its input." 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "metadata": { 248 | "id": "AVzntknpNqdW", 249 | "colab_type": "code", 250 | "colab": {} 251 | }, 252 | "source": [ 253 | "from torch import nn\n", 254 | "\n", 255 | "class DenseBlock(nn.Module):\n", 256 | " def __init__(self, input_dim, growth, depth):\n", 257 | " super(DenseBlock,self).__init__()\n", 258 | " ops=[]\n", 259 | " for i in range(depth):\n", 260 | " ops.append(nn.Sequential(nn.utils.weight_norm(nn.Linear(input_dim+i*growth, growth)), nn.ReLU() ) )\n", 261 | "\n", 262 | " self.ops = nn.ModuleList(ops)\n", 263 | "\n", 264 | " def forward(self,x):\n", 265 | " for op in self.ops:\n", 266 | " y = op(x)\n", 267 | " x = torch.cat([x,y],1)\n", 268 | " return x\n" 269 | ], 270 | "execution_count": null, 271 | "outputs": [] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": { 276 | "id": "T2l0BFnupIxc", 277 | "colab_type": "text" 278 | }, 279 | "source": [ 280 | "Note that we use Weight Norm as opposed to Batch Norm. The reasoning behind using Weight Norm is that Batch Norm introduces noise during training, which, although tolerable for classification, hurts our ability to precisely reconstruct the input. By using Weight Norm, any noise introduced in our VAE is counted towards the regularization loss. We can now construct our VAE architecture and define functions to perform inference. ```compute_negative_elbo``` computes the negative ELBO and will be used during training. 
```importance_sampling``` approximates the exact negative log-likelihood and will be used during evaluation.\n", 281 | "\n", 282 | "During training, we use a trick called the *free bits objective*. A problem with optimizing VAEs is that the latent space initially contains no information, so a lot of noise is applied to the space. This noise can push the optimization into a bad local minimum and cause swings in behavior depending on the random seed. To stabilize training, the free bits objective introduces a hyperparameter $\\alpha$ such that the regularization loss is inactive for latent variables containing less than $\\alpha$ bits of information. This allows the VAE to quickly learn to store $\\alpha$ bits in each latent variable, after which training can proceed stably with the standard negative ELBO objective. We want $\\alpha$ to be high enough that the VAE enters a stable regime of training, but we do not want $\\alpha$ to be too high, since the free bits objective is not the natural VAE objective. Moreover, if $\\alpha$ is too high, then $q_\\phi(\\mathbf{z}|\\mathbf{x})$ will very quickly have low variance, which may prevent exploration of the latent space. In our example we will use $\\alpha=0.05$ bits." 
283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "PwiaLXm4njnU", 289 | "colab_type": "code", 290 | "colab": {} 291 | }, 292 | "source": [ 293 | "import torch\n", 294 | "import math\n", 295 | "import torch.nn.functional as F\n", 296 | "\n", 297 | "class VAE(nn.Module):\n", 298 | " def __init__(self, latent_dim=2):\n", 299 | " super(VAE,self).__init__()\n", 300 | " #set up hyperparameters\n", 301 | " self.latent_dim = latent_dim\n", 302 | " growth=1024\n", 303 | " depth=6\n", 304 | "\n", 305 | " #define architecture\n", 306 | " encoder_dense = DenseBlock(2, growth,depth) \n", 307 | " encoder_linear = nn.utils.weight_norm(nn.Linear(2+growth*depth, self.latent_dim*2))\n", 308 | " self.encoder = nn.Sequential(encoder_dense, encoder_linear)\n", 309 | "\n", 310 | " decoder_dense = DenseBlock(self.latent_dim, growth,depth) \n", 311 | " decoder_linear = nn.utils.weight_norm(nn.Linear(self.latent_dim+growth*depth, 2*2))\n", 312 | " self.decoder = nn.Sequential(decoder_dense, decoder_linear)\n", 313 | "\n", 314 | " def encode(self,x):\n", 315 | " z_params = self.encoder(x)\n", 316 | " z_mu = z_params[:,:self.latent_dim]\n", 317 | " z_logvar = z_params[:,self.latent_dim:]\n", 318 | " return z_mu, z_logvar\n", 319 | "\n", 320 | " def decode(self,z):\n", 321 | " x_params = self.decoder(z)\n", 322 | " x_mu = x_params[:,:2]\n", 323 | " x_logvar = x_params[:,2:]\n", 324 | " return x_mu, x_logvar\n", 325 | "\n", 326 | " def reparameterize(self, mu, logvar):\n", 327 | " std = torch.exp(0.5*logvar)\n", 328 | " return mu + torch.cuda.FloatTensor(std.shape).normal_() * std\n", 329 | "\n", 330 | " def forward(self,x):\n", 331 | " z_mu, z_logvar = self.encode(x)\n", 332 | " z = self.reparameterize(z_mu, z_logvar)\n", 333 | " x_mu, x_logvar = self.decode(z)\n", 334 | " return x_mu,x_logvar, z_mu, z_logvar\n", 335 | "\n", 336 | "\n", 337 | " def compute_negative_elbo(self, x, freebits=0):\n", 338 | " x_mu, x_logvar, z_mu, z_logvar = self.forward(x)\n", 339 
| " l_rec = -torch.sum(gaussian_log_prob(x, x_mu, x_logvar),1)\n", 340 | " l_reg = torch.sum(F.relu(self.compute_kld(z_mu, z_logvar)-freebits*math.log(2))+freebits*math.log(2),1)\n", 341 | " return l_rec + l_reg, l_rec, l_reg\n", 342 | "\n", 343 | " def compute_kld(self,z_mu, z_logvar):\n", 344 | " return 0.5*(z_mu**2 + torch.exp(z_logvar) - 1 - z_logvar)\n", 345 | "\n", 346 | " \n", 347 | " def importance_sampling(self, x, importance_samples=1):\n", 348 | " z_mu, z_logvar = self.encode(x)\n", 349 | "\n", 350 | " z_mu = z_mu.unsqueeze(1).repeat((1,importance_samples,1))\n", 351 | " z_mu = z_mu.reshape((-1, self.latent_dim))\n", 352 | " z_logvar = z_logvar.unsqueeze(1).repeat((1,importance_samples,1))\n", 353 | " z_logvar = z_logvar.reshape((-1, self.latent_dim))\n", 354 | " x = x.unsqueeze(1).repeat((1,importance_samples,1))\n", 355 | " x = x.reshape((-1,2))\n", 356 | "\n", 357 | " z = self.reparameterize(z_mu, z_logvar)\n", 358 | " x_mu, x_logvar = self.decode(z)\n", 359 | " x_mu = x_mu.reshape((-1,importance_samples,2))\n", 360 | " x_logvar = x_logvar.reshape((-1,importance_samples,2))\n", 361 | "\n", 362 | " x = x.reshape((-1,importance_samples,2))\n", 363 | " z = z.reshape((-1,importance_samples,self.latent_dim))\n", 364 | " z_mu = z_mu.reshape((-1,importance_samples,self.latent_dim))\n", 365 | " z_logvar = z_logvar.reshape((-1,importance_samples,self.latent_dim))\n", 366 | "\n", 367 | " logpxz = gaussian_log_prob(x, x_mu,x_logvar)\n", 368 | " logpz = gaussian_log_prob(z, torch.zeros(z.shape).cuda(), torch.zeros(z.shape).cuda())\n", 369 | " logqzx = gaussian_log_prob(z, z_mu, z_logvar)\n", 370 | "\n", 371 | " logprob = logpxz+logpz - logqzx\n", 372 | " logprob = torch.sum(logprob,2)\n", 373 | " logprob = log_mean_exp(logprob, 1)\n", 374 | "\n", 375 | " return -logprob\n", 376 | "\n", 377 | "def gaussian_log_prob(z, mu, logvar):\n", 378 | " return -0.5*(math.log(2*math.pi) + logvar + (z-mu)**2/torch.exp(logvar))\n", 379 | "\n", 380 | "def 
log_mean_exp(x,axis):\n", 381 | " m,_=torch.max(x,axis)\n", 382 | " m2,_ = torch.max(x,axis,keepdim=True)\n", 383 | " return m + torch.log(torch.mean(torch.exp(x-m2),axis))\n", 384 | "\n" 385 | ], 386 | "execution_count": null, 387 | "outputs": [] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": { 392 | "id": "glz3iGicAwWh", 393 | "colab_type": "text" 394 | }, 395 | "source": [ 396 | "We now have all the tools we need to train a VAE. We will sample a batch of 40000 points each iteration and minimize the free-bits negative ELBO over the course of 60000 iterations. Each model should take roughly 20 minutes to train." 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "metadata": { 402 | "id": "kDaTSfbIBWUu", 403 | "colab_type": "code", 404 | "colab": {} 405 | }, 406 | "source": [ 407 | "import time\n", 408 | "\n", 409 | "def s2hms(s):\n", 410 | " h = s//3600\n", 411 | " m = (s-h*3600)//60\n", 412 | " s = int(s-h*3600-m*60)\n", 413 | " return h,m,s\n", 414 | "\n", 415 | "def print_progress(time, cur_iter, total_iter):\n", 416 | " h,m,s = s2hms(time)\n", 417 | " h2,m2,s2 = s2hms(time*total_iter/cur_iter - time)\n", 418 | " print('Time Elapsed: %d hours %d minutes %d seconds. 
Time Remaining: %d hours %d minutes %d seconds.'%(h,m,s,h2,m2,s2))\n", 419 | "\n", 420 | "\n", 421 | "def train_vae(dataset, model=None, epochs=60000, print_freq=1000):\n", 422 | " if model is None:\n", 423 | " model = VAE().cuda()\n", 424 | " optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)\n", 425 | " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,\n", 426 | " patience=epochs//20,\n", 427 | " min_lr=1e-8, verbose=True,\n", 428 | " threshold_mode='abs')\n", 429 | " start=time.time()\n", 430 | " loss_ema=0\n", 431 | " best_ema =1e9\n", 432 | " for iteration in range(epochs): #train for 60k iterations\n", 433 | " data = torch.tensor(sample2d(dataset,40000)).float().cuda()\n", 434 | " neg_elbo, l_rec, l_reg = model.compute_negative_elbo(data,freebits=0.05)\n", 435 | "\n", 436 | " loss = torch.mean(l_reg + l_rec )/math.log(2)\n", 437 | " loss.backward()\n", 438 | " optimizer.step()\n", 439 | " optimizer.zero_grad()\n", 440 | " loss_ema = 0.999*loss_ema + 0.001*loss\n", 441 | " data=None\n", 442 | " #scheduler.step(loss_ema)\n", 443 | " if iteration == int(epochs*0.6) or iteration == int(epochs*0.7) or iteration == int(epochs*0.8) or iteration == int(epochs*0.9):\n", 444 | " for param_group in optimizer.param_groups:\n", 445 | " param_group['lr'] /= 2\n", 446 | " if iteration%print_freq == 0:\n", 447 | " with torch.no_grad():\n", 448 | " #print('Iteration %d. Loss: %f'%(iteration, loss))\n", 449 | " #neg_elbo, l_rec, l_reg = model.compute_negative_elbo(data,0)\n", 450 | " print('Iteration %d. 
EMA: %f ELBO: %f L_rec: %f L_reg: %f'%(iteration, loss_ema, torch.mean(neg_elbo)/math.log(2),torch.mean(l_rec)/math.log(2), torch.mean(l_reg)/math.log(2)))\n", 451 | " print_progress(time.time()-start, iteration+1, epochs)\n", 452 | "\n", 453 | " return model" 454 | ], 455 | "execution_count": null, 456 | "outputs": [] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "metadata": { 461 | "id": "HSY9QsfOELDV", 462 | "colab_type": "code", 463 | "colab": {} 464 | }, 465 | "source": [ 466 | "model8g = train_vae('8gaussians',None,60000,100) #should take ~20 minutes to train" 467 | ], 468 | "execution_count": null, 469 | "outputs": [] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "metadata": { 474 | "id": "N68TnDSMbvCH", 475 | "colab_type": "code", 476 | "colab": {} 477 | }, 478 | "source": [ 479 | "modelc = train_vae('checkerboard',None,60000,100) #should take ~20 minutes to train" 480 | ], 481 | "execution_count": null, 482 | "outputs": [] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "metadata": { 487 | "id": "KqzRmUbOby9i", 488 | "colab_type": "code", 489 | "colab": {} 490 | }, 491 | "source": [ 492 | "model2s = train_vae('2spirals',None,60000,100) #should take ~20 minutes to train" 493 | ], 494 | "execution_count": null, 495 | "outputs": [] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": { 500 | "id": "LLZBNVuzuGSr", 501 | "colab_type": "text" 502 | }, 503 | "source": [ 504 | "We see that the reconstruction loss decreases as training progresses while the regularization loss increases.\n", 505 | "Now that we have trained a model, we can evaluate it by comparing the negative log-likelihood with the ground-truth distribution and by plotting the density map of the model. The negative log-likelihood is approximated with importance sampling." 
506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "metadata": { 511 | "id": "HHVycHB2utM5", 512 | "colab_type": "code", 513 | "colab": {} 514 | }, 515 | "source": [ 516 | "def plot_density2d(model, max_prob, gt_prob, importance_samples=1):\n", 517 | " #code largely taken from https://github.com/nicola-decao/BNAF/blob/master/toy2d.py\n", 518 | " limit=0.5\n", 519 | " step=1/1024.0\n", 520 | " grid = torch.Tensor([[a, b] for a in np.arange(-limit, limit, step) for b in np.arange(-limit, limit, step)])\n", 521 | " grid_dataset = torch.utils.data.TensorDataset(grid.cuda())\n", 522 | " grid_data_loader = torch.utils.data.DataLoader(grid_dataset, batch_size=20000//importance_samples, shuffle=False)\n", 523 | "\n", 524 | " l=[]\n", 525 | " start=time.time()\n", 526 | " with torch.no_grad():\n", 527 | " for idx, (x_mb,) in enumerate(grid_data_loader):\n", 528 | " temp= model.importance_sampling(x_mb,importance_samples)\n", 529 | " l.append(torch.exp(-temp))\n", 530 | " if idx % 600 == 0 and idx>0:\n", 531 | " print_progress(time.time()-start, idx, len(grid_data_loader))\n", 532 | " prob = torch.cat(l, 0)\n", 533 | " \n", 534 | " prob = prob.view(int(2 * limit / step), int(2 * limit / step))\n", 535 | " prob[prob!=prob]=0 #set nan probabilities to 0\n", 536 | "\n", 537 | " prob+=1e-20\n", 538 | " prob = prob/1024/1024\n", 539 | " nll = - gt_prob * np.log(prob.cpu().data.numpy())/np.log(2)\n", 540 | " nll = np.sum(nll)\n", 541 | " print('Negative Log Likelihood' , nll-20)\n", 542 | "\n", 543 | " prob /= torch.sum(prob)\n", 544 | " prob = prob.clamp(max=max_prob)\n", 545 | "\n", 546 | " prob = prob.cpu().data.numpy()\n", 547 | "\n", 548 | " fig = plt.figure(figsize=(18, 18))\n", 549 | " \n", 550 | " ax1 = fig.add_subplot(1,2,1)\n", 551 | " ax1.axis('off')\n", 552 | " ax1.imshow(gt_prob, extent=(-limit, limit, -limit, limit))\n", 553 | "\n", 554 | " ax2 = fig.add_subplot(1,2,2)\n", 555 | " ax2.axis('off')\n", 556 | " ax2.imshow(prob, extent=(-limit, limit, -limit, 
limit))\n", 557 | "\n", 558 | " fig.tight_layout()" 559 | ], 560 | "execution_count": null, 561 | "outputs": [] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "metadata": { 566 | "id": "vkODCuaCzqBw", 567 | "colab_type": "text" 568 | }, 569 | "source": [ 570 | "Let us first quickly evaluate our models using 1 importance sample." 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "metadata": { 576 | "id": "avKEPoK209lK", 577 | "colab_type": "code", 578 | "colab": {} 579 | }, 580 | "source": [ 581 | "plot_density2d(model8g, max_prob8g,prob8g, 1)\n" 582 | ], 583 | "execution_count": null, 584 | "outputs": [] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "metadata": { 589 | "id": "6eW42o81b3g4", 590 | "colab_type": "code", 591 | "colab": {} 592 | }, 593 | "source": [ 594 | "plot_density2d(modelc, max_probc,probc, 1)\n" 595 | ], 596 | "execution_count": null, 597 | "outputs": [] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "metadata": { 602 | "id": "s0Bzl5Z3b7wH", 603 | "colab_type": "code", 604 | "colab": {} 605 | }, 606 | "source": [ 607 | "plot_density2d(model2s, max_prob2s,prob2s, 1)" 608 | ], 609 | "execution_count": null, 610 | "outputs": [] 611 | }, 612 | { 613 | "cell_type": "markdown", 614 | "metadata": { 615 | "id": "ki9ZZE65zSUX", 616 | "colab_type": "text" 617 | }, 618 | "source": [ 619 | "When we use only 1 importance sample, then the result is essentially the negative ELBO. We will now better approximate the likelihood by taking 250 importance samples, which will take 250 times longer. This provides a more accurate evaluation of the model performance and gives us an idea of how tight of a bound the negative ELBO is." 
620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "metadata": { 625 | "id": "61CNugwGDDy7", 626 | "colab_type": "code", 627 | "colab": {} 628 | }, 629 | "source": [ 630 | "plot_density2d(model8g, max_prob8g,prob8g, 250) #should take around 20 minutes" 631 | ], 632 | "execution_count": null, 633 | "outputs": [] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "metadata": { 638 | "id": "3n-fohH6fkNp", 639 | "colab_type": "code", 640 | "colab": {} 641 | }, 642 | "source": [ 643 | "plot_density2d(modelc, max_probc,probc, 250) #should take around 20 minutes" 644 | ], 645 | "execution_count": null, 646 | "outputs": [] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "metadata": { 651 | "id": "-X2N6-G1fnjg", 652 | "colab_type": "code", 653 | "colab": {} 654 | }, 655 | "source": [ 656 | "plot_density2d(model2s, max_prob2s,prob2s, 250) #should take around 20 minutes" 657 | ], 658 | "execution_count": null, 659 | "outputs": [] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "metadata": { 664 | "id": "zDYelNNFxTJL", 665 | "colab_type": "text" 666 | }, 667 | "source": [ 668 | "We can also qualitatively understand the unsupervised representation learning abilities of the VAE by visualizing the correspondence between the latent space and the color maps of the ground truth distribution. 
We also print out the probability mass of low-density ``dark pixels\" to quantitatively evaluate how ``filled\" the latent space is.\n" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "metadata": { 674 | "id": "5rKGQMRfydzb", 675 | "colab_type": "code", 676 | "colab": {} 677 | }, 678 | "source": [ 679 | "def visualize_latent(model, colormap_gt):\n", 680 | " limit=2\n", 681 | " step=1/256.0\n", 682 | " grid = torch.Tensor([[a, b] for a in np.arange(-limit, limit, step) for b in np.arange(-limit, limit, step)])\n", 683 | " grid_dataset = torch.utils.data.TensorDataset(grid.cuda())\n", 684 | " grid_data_loader = torch.utils.data.DataLoader(grid_dataset, batch_size=20000, shuffle=False)\n", 685 | " colormap_gt = colormap_gt.reshape((1024*1024,3))\n", 686 | " l = []\n", 687 | " with torch.no_grad():\n", 688 | " for z ,in grid_data_loader:\n", 689 | " x,_ = model.decode(z) #find the value that each latent vector maps to\n", 690 | " l.append(x)\n", 691 | " x = torch.cat(l,0)\n", 692 | " x=(x*1024+512).long() #find the corresponding pixel in the color map \n", 693 | " x[x<0]=0\n", 694 | " x[x>1023]=1023\n", 695 | " y = x[:,0]*1024+x[:,1]\n", 696 | " y = y.reshape((-1,1)).repeat((1,3))\n", 697 | "\n", 698 | " colormap_pred = torch.gather(torch.tensor(colormap_gt).cuda(), 0,y)\n", 699 | " pz = torch.exp(torch.sum(gaussian_log_prob(grid.cuda(), torch.zeros(grid.shape).cuda(), torch.zeros(grid.shape).cuda()),-1))\n", 700 | " pz/=torch.sum(pz)\n", 701 | " color = col.rgb_to_hsv(colormap_pred.cpu().numpy())[:,2]\n", 702 | " print('Dark Pixels', torch.sum(pz*(torch.tensor(color).cuda()<0.01).float()).item())\n", 703 | "\n", 704 | "\n", 705 | " colormap_pred = colormap_pred.reshape((1024,1024,3)).cpu().data.numpy()\n", 706 | " colormap_gt = colormap_gt.reshape((1024,1024,3))\n", 707 | "\n", 708 | "\n", 709 | "\n", 710 | "\n", 711 | " fig = plt.figure(figsize=(18, 18))\n", 712 | " \n", 713 | " ax1 = fig.add_subplot(1,2,1)\n", 714 | " ax1.axis('off')\n", 715 | " 
ax1.imshow(colormap_gt, extent=(-0.5, 0.5, -0.5, 0.5))\n", 716 | "\n", 717 | " ax2 = fig.add_subplot(1,2,2)\n", 718 | " ax2.axis('off')\n", 719 | " ax2.imshow(colormap_pred, extent=(-limit, limit, -limit, limit))\n", 720 | "\n", 721 | " fig.tight_layout()\n", 722 | "\n" 723 | ], 724 | "execution_count": null, 725 | "outputs": [] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "metadata": { 730 | "id": "B6I-CPRHFb0J", 731 | "colab_type": "code", 732 | "colab": {} 733 | }, 734 | "source": [ 735 | "visualize_latent(model8g,color8g)" 736 | ], 737 | "execution_count": null, 738 | "outputs": [] 739 | }, 740 | { 741 | "cell_type": "code", 742 | "metadata": { 743 | "id": "5ISzcTrcfAtE", 744 | "colab_type": "code", 745 | "colab": {} 746 | }, 747 | "source": [ 748 | "visualize_latent(modelc,colorc)" 749 | ], 750 | "execution_count": null, 751 | "outputs": [] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "metadata": { 756 | "id": "WKqFQT4jfDEy", 757 | "colab_type": "code", 758 | "colab": {} 759 | }, 760 | "source": [ 761 | "visualize_latent(model2s,color2s)" 762 | ], 763 | "execution_count": null, 764 | "outputs": [] 765 | }, 766 | { 767 | "cell_type": "markdown", 768 | "metadata": { 769 | "id": "tz7oJ5vejQeR", 770 | "colab_type": "text" 771 | }, 772 | "source": [ 773 | "## IWAEs\n", 774 | "We now replace the negative ELBO training objective with an importance-sampling approximation of the negative log-likelihood, using 10 importance samples per data point." 
775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "metadata": { 780 | "id": "ITr9uPZAGdRT", 781 | "colab_type": "code", 782 | "colab": {} 783 | }, 784 | "source": [ 785 | "def train_iwae(dataset, model=None, epochs=60000, print_freq=1000):\n", 786 | " if model is None:\n", 787 | " model = VAE().cuda()\n", 788 | " optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)\n", 789 | " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,\n", 790 | " patience=epochs/20,\n", 791 | " min_lr=1e-8, verbose=True,\n", 792 | " threshold_mode='abs')\n", 793 | " start=time.time()\n", 794 | " loss_ema=0\n", 795 | " for iteration in range(epochs): #train for 60k iterations\n", 796 | " data = torch.tensor(sample2d(dataset,200)).float().cuda()\n", 797 | " nll = model.importance_sampling(data,10)\n", 798 | " loss = torch.mean(nll)/math.log(2)\n", 799 | " loss.backward()\n", 800 | " optimizer.step()\n", 801 | " optimizer.zero_grad()\n", 802 | " loss_ema = 0.999*loss_ema + 0.001*loss\n", 803 | " #scheduler.step(loss_ema)\n", 804 | " if iteration == int(epochs*0.6) or iteration == int(epochs*0.7) or iteration == int(epochs*0.8) or iteration == int(epochs*0.9):\n", 805 | " for param_group in optimizer.param_groups:\n", 806 | " param_group['lr'] /= 2\n", 807 | " if iteration%print_freq == 0:\n", 808 | " with torch.no_grad():\n", 809 | " print('Iteration %d. Loss: %f'%(iteration, loss))\n", 810 | " neg_elbo, l_rec, l_reg = model.compute_negative_elbo(data,0)\n", 811 | " print('Iteration %d. 
EMA: %f ELBO: %f L_rec: %f L_reg: %f'%(iteration, loss_ema, torch.mean(neg_elbo)/math.log(2),torch.mean(l_rec)/math.log(2), torch.mean(l_reg)/math.log(2)))\n", 812 | " print_progress(time.time()-start, iteration+1, epochs)\n", 813 | "\n", 814 | " return model" 815 | ], 816 | "execution_count": null, 817 | "outputs": [] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "metadata": { 822 | "id": "6NbdjJV8GUKY", 823 | "colab_type": "code", 824 | "colab": {} 825 | }, 826 | "source": [ 827 | "model8g_iwae = train_iwae('8gaussians', None, 30000,400) #should take ~20 minutes to train" 828 | ], 829 | "execution_count": null, 830 | "outputs": [] 831 | }, 832 | { 833 | "cell_type": "code", 834 | "metadata": { 835 | "id": "mI--iqyMfydF", 836 | "colab_type": "code", 837 | "colab": {} 838 | }, 839 | "source": [ 840 | "modelc_iwae = train_iwae('checkerboard', None, 30000,400) #should take ~20 minutes to train" 841 | ], 842 | "execution_count": null, 843 | "outputs": [] 844 | }, 845 | { 846 | "cell_type": "code", 847 | "metadata": { 848 | "id": "uurxajhhJOsU", 849 | "colab_type": "code", 850 | "colab": {} 851 | }, 852 | "source": [ 853 | "model2s_iwae = train_iwae('2spirals', None, 30000,400) #should take ~20 minutes to train" 854 | ], 855 | "execution_count": null, 856 | "outputs": [] 857 | }, 858 | { 859 | "cell_type": "markdown", 860 | "metadata": { 861 | "id": "cBrAwBPBjg0g", 862 | "colab_type": "text" 863 | }, 864 | "source": [ 865 | "We see that compared to a Typical VAE, $D_{KL}(q_\\phi(\\mathbf{z}|\\mathbf{x}), p_\\theta(\\mathbf{z}))$ is lower, indicating that the variance of $q_\\phi(\\mathbf{z}|\\mathbf{x})$ is higher than in a VAE. We now evaluate our model and check how accurate of an approximation our objective was to the true negative log likelihood." 
866 | ]
867 | },
868 | {
869 | "cell_type": "code",
870 | "metadata": {
871 | "id": "whYqBepvMZrt",
872 | "colab_type": "code",
873 | "colab": {}
874 | },
875 | "source": [
876 | "plot_density2d(model8g_iwae, max_prob8g, prob8g, 10)"
877 | ],
878 | "execution_count": null,
879 | "outputs": []
880 | },
881 | {
882 | "cell_type": "code",
883 | "metadata": {
884 | "id": "OwxTRZ25gFGG",
885 | "colab_type": "code",
886 | "colab": {}
887 | },
888 | "source": [
889 | "plot_density2d(modelc_iwae, max_probc, probc, 10)"
890 | ],
891 | "execution_count": null,
892 | "outputs": []
893 | },
894 | {
895 | "cell_type": "code",
896 | "metadata": {
897 | "id": "dRhQtzSogGPb",
898 | "colab_type": "code",
899 | "colab": {}
900 | },
901 | "source": [
902 | "plot_density2d(model2s_iwae, max_prob2s, prob2s, 10)"
903 | ],
904 | "execution_count": null,
905 | "outputs": []
906 | },
907 | {
908 | "cell_type": "code",
909 | "metadata": {
910 | "id": "sAtCDn3ThXnJ",
911 | "colab_type": "code",
912 | "colab": {}
913 | },
914 | "source": [
915 | "plot_density2d(model8g_iwae, max_prob8g, prob8g, 250)  # should take ~20 minutes"
916 | ],
917 | "execution_count": null,
918 | "outputs": []
919 | },
920 | {
921 | "cell_type": "code",
922 | "metadata": {
923 | "id": "7ld8BN7ihZjD",
924 | "colab_type": "code",
925 | "colab": {}
926 | },
927 | "source": [
928 | "plot_density2d(modelc_iwae, max_probc, probc, 250)  # should take ~20 minutes"
929 | ],
930 | "execution_count": null,
931 | "outputs": []
932 | },
933 | {
934 | "cell_type": "code",
935 | "metadata": {
936 | "id": "N5GonnEOhaB-",
937 | "colab_type": "code",
938 | "colab": {}
939 | },
940 | "source": [
941 | "plot_density2d(model2s_iwae, max_prob2s, prob2s, 250)  # should take ~20 minutes"
942 | ],
943 | "execution_count": null,
944 | "outputs": []
945 | },
946 | {
947 | "cell_type": "markdown",
948 | "metadata": {
949 | "id": "Tsi7H3uKjrP1",
950 | "colab_type": "text"
951 | },
952 | "source": [
953 | "We now visualize the latent space of IWAE. 
We see that, compared to Typical VAEs, much less of the latent space is mapped to low-density inputs."
954 | ]
955 | },
956 | {
957 | "cell_type": "code",
958 | "metadata": {
959 | "id": "KIozAKHUMflG",
960 | "colab_type": "code",
961 | "colab": {}
962 | },
963 | "source": [
964 | "visualize_latent(model8g_iwae, color8g)"
965 | ],
966 | "execution_count": null,
967 | "outputs": []
968 | },
969 | {
970 | "cell_type": "code",
971 | "metadata": {
972 | "id": "Hb3mpF6GgSP8",
973 | "colab_type": "code",
974 | "colab": {}
975 | },
976 | "source": [
977 | "visualize_latent(modelc_iwae, colorc)"
978 | ],
979 | "execution_count": null,
980 | "outputs": []
981 | },
982 | {
983 | "cell_type": "code",
984 | "metadata": {
985 | "id": "FYP4l--_gS0K",
986 | "colab_type": "code",
987 | "colab": {}
988 | },
989 | "source": [
990 | "visualize_latent(model2s_iwae, color2s)"
991 | ],
992 | "execution_count": null,
993 | "outputs": []
994 | },
995 | {
996 | "cell_type": "code",
997 | "metadata": {
998 | "id": "rRowqVPs9ONg",
999 | "colab_type": "code",
1000 | "colab": {}
1001 | },
1002 | "source": [
1003 | "print('Marginal KL', calc_marginal_kl(model8g_iwae, '8gaussians'))  # should take ~1 minute"
1004 | ],
1005 | "execution_count": null,
1006 | "outputs": []
1007 | },
1008 | {
1009 | "cell_type": "code",
1010 | "metadata": {
1011 | "id": "KtwZzDvs9Q3A",
1012 | "colab_type": "code",
1013 | "colab": {}
1014 | },
1015 | "source": [
1016 | "print('Marginal KL', calc_marginal_kl(modelc_iwae, 'checkerboard'))"
1017 | ],
1018 | "execution_count": null,
1019 | "outputs": []
1020 | },
1021 | {
1022 | "cell_type": "code",
1023 | "metadata": {
1024 | "id": "87tJNoMb9RH4",
1025 | "colab_type": "code",
1026 | "colab": {}
1027 | },
1028 | "source": [
1029 | "print('Marginal KL', calc_marginal_kl(model2s_iwae, '2spirals'))"
1030 | ],
1031 | "execution_count": null,
1032 | "outputs": []
1033 | }
1034 | ]
1035 | }
--------------------------------------------------------------------------------