├── .gitignore
├── README.md
├── requirements.txt
└── src
    ├── baseline.py
    ├── cvae.ipynb
    ├── cvae.py
    ├── main.py
    ├── mnist.py
    ├── util.py
    └── video.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | src/.ipynb_checkpoints
3 | src/__pycache__
4 | data/
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Variational Auto-encoder
2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-structured-output-representation/structured-prediction-on-mnist)](https://paperswithcode.com/sota/structured-prediction-on-mnist?p=learning-structured-output-representation)
3 | 
4 | ## Introduction
5 | 
6 | This tutorial implements the [Learning Structured Output Representation using Deep Conditional Generative Models](http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generati) paper, which introduced Conditional Variational Auto-encoders in 2015, using Pyro PPL.
7 | 
8 | Supervised deep learning has been successfully applied to many recognition problems in machine learning and computer vision.
9 | Although it can approximate a complex many-to-one function very well when a large amount of training data is provided, the lack of probabilistic inference in current supervised deep learning methods makes it difficult to model complex structured output representations.
10 | In this work, Kihyuk Sohn, Honglak Lee and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables.
11 | The model is trained efficiently in the framework of stochastic gradient variational Bayes, and allows fast prediction using stochastic feed-forward inference.
12 | They called the model the Conditional Variational Auto-encoder (CVAE).
13 | 
14 | The CVAE is a conditional directed graphical model whose input observations modulate the prior on Gaussian latent variables that generate the outputs.
15 | It is trained to maximize the conditional marginal log-likelihood.
16 | The authors formulate the variational learning objective of the CVAE in the framework of stochastic gradient variational Bayes (SGVB).
17 | In experiments, they demonstrate the effectiveness of the CVAE, in comparison to its deterministic neural network counterparts, in generating diverse but realistic output predictions using stochastic inference.
18 | Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using the MNIST database.
19 | 
20 | ## The problem
21 | Let's divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as the output to be predicted.
22 | The image below shows the case where one quadrant is the input:
23 | 
24 | image1
25 | 
26 | Our objective is to **learn a model that can perform probabilistic inference and make diverse predictions from a single input**.
27 | This is because we are not simply modeling a many-to-one function as in classification tasks; we may need to model a mapping from a single input to many possible outputs. One of the limitations of deterministic neural networks is that they generate only a single prediction.
28 | In the example above, the input shows a small part of a digit that might be a three or a five.
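29 | 
30 | To make this concrete, here is a minimal, self-contained sketch of the masking convention used throughout this repo. It mirrors the `MaskImages` transform in `src/mnist.py` for the one-quadrant case; the function name is illustrative only, and the random tensor is just a stand-in for a real MNIST digit:
31 | 
32 | ```python
33 | import torch
34 | 
35 | 
36 | def split_quadrants(image, mask_with=-1):
37 |     """Split a digit image into an input (bottom-left quadrant) and a
38 |     target output (the other three quadrants), hiding pixels with -1."""
39 |     h, w = image.shape
40 |     out = image.clone()
41 |     out[h // 2:, :w // 2] = mask_with  # hide the input quadrant in the target
42 |     inp = image.clone()
43 |     inp[out != mask_with] = mask_with  # the input is the target's complement
44 |     return inp, out
45 | 
46 | 
47 | inp, out = split_quadrants(torch.rand(28, 28))
48 | assert ((inp == -1) | (out == -1)).all()  # every pixel is hidden on one side
49 | ```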
50 | 
51 | ## Evaluating the results
52 | For qualitative analysis, we visualize the generated output samples in the next figure. As we can see, the baseline NN can only make a single deterministic prediction, and as a result the output looks blurry and doesn't look realistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they can even change their identity (digit labels), such as from 3 to 5 or from 4 to 9, and vice versa.
53 | 
54 | image1
55 | 
56 | We also provide quantitative evidence by estimating the marginal conditional log-likelihoods (CLLs) in the next table (lower is better):
57 | 
58 | |                    | 1 quadrant | 2 quadrants | 3 quadrants |
59 | |--------------------|------------|-------------|-------------|
60 | | NN (baseline)      | 100.4      | 61.9        | 25.4        |
61 | | CVAE (Monte Carlo) | 71.8       | 51.0        | 24.2        |
62 | | Performance gap    | 28.6       | 10.9        | 1.2         |
63 | 
64 | We achieved results similar to the ones reported by the authors in the paper. We trained for only 50 epochs with an early stopping patience of 3 epochs; to improve the results, we could let the algorithm train for longer. Nevertheless, we can observe the same effect shown in the paper: **the CVAE significantly outperforms the baseline NN in estimated CLL**.
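65 | 
66 | The CVAE row is a Monte Carlo estimate: the exact CLL is intractable, so it is approximated by the negative multi-particle ELBO (which upper-bounds the negative CLL, making the estimate conservative), averaged over the validation set. The sketch below condenses how `generate_table` in `src/util.py` computes this; the `cvae` object and dataloader are assumed to come from this repo's training code:
67 | 
68 | ```python
69 | from pyro.infer import Trace_ELBO
70 | 
71 | 
72 | def estimate_cll(cvae, dataloader, num_particles=10, device='cpu'):
73 |     """Estimate the negative CLL of a trained CVAE, averaged per image,
74 |     using a multi-particle ELBO as a conservative stand-in."""
75 |     loss_fn = Trace_ELBO(num_particles=num_particles).differentiable_loss
76 |     total, num_batches = 0.0, 0
77 |     for batch in dataloader:
78 |         xs = batch['input'].to(device)
79 |         ys = batch['output'].to(device)
80 |         # negative ELBO for the batch, averaged over the batch size
81 |         total += loss_fn(cvae.model, cvae.guide, xs, ys).item() / xs.size(0)
82 |         num_batches += 1
83 |     return total / num_batches
84 | ```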
85 | 
86 | See the full code on [Github](https://github.com/ucals/cvae).
87 | 
88 | ## IMPORTANT
89 | There have been issue reports when trying to run the code with Pyro versions different from the one in `requirements.txt`.
90 | So, to make sure the code works, the recommended way is to create a clean virtual environment (conda or virtualenv) and run `pip install -r requirements.txt` in this new environment.
91 | 
92 | ## References
93 | 
94 | [1] `Learning Structured Output Representation using Deep Conditional Generative Models`,
95 | Kihyuk Sohn, Xinchen Yan, Honglak Lee
96 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyro-ppl==1.4.0
2 | pandas==1.5.3
3 | torch==1.13.1
4 | torchvision==0.14.1
5 | matplotlib==3.7.1
6 | scikit-learn==1.2.2
--------------------------------------------------------------------------------
/src/baseline.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | from pathlib import Path
4 | from tqdm import tqdm
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | 
10 | class BaselineNet(nn.Module):
11 |     def __init__(self, hidden_1, hidden_2):
12 |         super().__init__()
13 |         self.fc1 = nn.Linear(784, hidden_1)
14 |         self.fc2 = nn.Linear(hidden_1, hidden_2)
15 |         self.fc3 = nn.Linear(hidden_2, 784)
16 |         self.relu = nn.ReLU()
17 | 
18 |     def forward(self, x):
19 |         x = x.view(-1, 784)
20 |         hidden = self.relu(self.fc1(x))
21 |         hidden = self.relu(self.fc2(hidden))
22 |         y = torch.sigmoid(self.fc3(hidden))
23 |         return y
24 | 
25 | 
26 | class MaskedBCELoss(nn.Module):
27 |     def __init__(self, masked_with=-1):
28 |         super().__init__()
29 |         self.masked_with = masked_with
30 | 
31 |     def forward(self, input, target):
32 |         target = target.view(input.shape)
33 |         loss = F.binary_cross_entropy(input, target, reduction='none')
34 |         # do not penalize predictions on the masked (input) pixels
35 |         loss[target == self.masked_with] = 0
36 |         return loss.sum()
37 | 
38 | 
39 | def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs,
40 |           early_stop_patience, model_path):
41 | 
42 |     # Train baseline
43 |     baseline_net = BaselineNet(500, 500)
44 |     baseline_net.to(device)
45 |     optimizer = torch.optim.Adam(baseline_net.parameters(), lr=learning_rate)
46 |     criterion = MaskedBCELoss()
47 |     best_loss = np.inf
48 |     # keep an initial copy so load_state_dict below is always defined
49 |     best_model_wts = copy.deepcopy(baseline_net.state_dict())
50 |     early_stop_count = 0
51 | 
52 |     for epoch in range(num_epochs):
53 |         for phase in ['train', 'val']:
54 |             if phase == 'train':
55 |                 baseline_net.train()
56 |             else:
57 |                 baseline_net.eval()
58 | 
59 |             running_loss = 0.0
60 |             num_preds = 0
61 | 
62 |             bar = tqdm(dataloaders[phase],
63 |                        desc='NN Epoch {} {}'.format(epoch, phase).ljust(20))
64 |             for i, batch in enumerate(bar):
65 |                 inputs = batch['input'].to(device)
66 |                 outputs = batch['output'].to(device)
67 | 
68 |                 optimizer.zero_grad()
69 | 
70 |                 with torch.set_grad_enabled(phase == 'train'):
71 |                     preds = baseline_net(inputs)
72 |                     loss = criterion(preds, outputs) / inputs.size(0)
73 |                     if phase == 'train':
74 |                         loss.backward()
75 |                         optimizer.step()
76 | 
77 |                 running_loss += loss.item()
78 |                 num_preds += 1
79 |                 if i % 10 == 0:
80 |                     bar.set_postfix(loss='{:.2f}'.format(running_loss / num_preds),
81 |                                     early_stop_count=early_stop_count)
82 | 
83 |             epoch_loss = running_loss / dataset_sizes[phase]
84 |             # deep copy the model
85 |             if phase == 'val':
86 |                 if epoch_loss < best_loss:
87 |                     best_loss = epoch_loss
88 |                     best_model_wts = copy.deepcopy(baseline_net.state_dict())
89 |                     early_stop_count = 0
90 |                 else:
91 |                     early_stop_count += 1
92 | 
93 |         if early_stop_count >= early_stop_patience:
94 |             break
95 | 
96 |     baseline_net.load_state_dict(best_model_wts)
97 |     baseline_net.eval()
98 | 
99 |     # Save model weights
100 |     Path(model_path).parent.mkdir(parents=True, exist_ok=True)
101 |     torch.save(baseline_net.state_dict(), model_path)
102 | 
103 |     return baseline_net
--------------------------------------------------------------------------------
/src/cvae.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Conditional Variational Auto-encoder\n",
8 |     "\n",
9 |     "## Introduction\n",
10 |     "This tutorial implements the [Learning Structured Output Representation using Deep Conditional Generative Models](http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generati) paper, which introduced Conditional Variational Auto-encoders in 2015, using Pyro PPL.\n",
11 |     "\n",
12 |     "Supervised deep learning has been successfully applied to many recognition problems in machine learning and computer vision. Although it can approximate a complex many-to-one function very well when a large amount of training data is provided, the lack of probabilistic inference in current supervised deep learning methods makes it difficult to model complex structured output representations. In this work, Kihyuk Sohn, Honglak Lee and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables. The model is trained efficiently in the framework of stochastic gradient variational Bayes, and allows fast prediction using stochastic feed-forward inference. They called the model the Conditional Variational Auto-encoder (CVAE).\n",
13 |     "\n",
14 |     "The CVAE is a conditional directed graphical model whose input observations modulate the prior on Gaussian latent variables that generate the outputs. It is trained to maximize the conditional marginal log-likelihood. The authors formulate the variational learning objective of the CVAE in the framework of stochastic gradient variational Bayes (SGVB). In experiments, they demonstrate the effectiveness of the CVAE, in comparison to its deterministic neural network counterparts, in generating diverse but realistic output predictions using stochastic inference. Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using the MNIST database.\n",
15 |     "\n",
16 |     "## The problem\n",
17 |     "Let's divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as the output to be predicted. The image below shows the case where one quadrant is the input:\n",
18 |     "\n",
19 |     "\"image1\"\n",
20 |     "\n",
21 |     "Our objective is to **learn a model that can perform probabilistic inference and make diverse predictions from a single input**. This is because we are not simply modeling a many-to-one function as in classification tasks; we may need to model a mapping from a single input to many possible outputs. One of the limitations of deterministic neural networks is that they generate only a single prediction. In the example above, the input shows a small part of a digit that might be a three or a five. \n",
22 |     "\n",
23 |     "## Preparing the data\n",
24 |     "We use the MNIST dataset; the first step is to prepare it. 
Depending on how many quadrants we will use as inputs, we will build the datasets and dataloaders, removing the unused pixels with -1:" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "class CVAEMNIST(Dataset):\n", 34 | " def __init__(self, root, train=True, transform=None, download=False):\n", 35 | " self.original = MNIST(root, train=train, download=download)\n", 36 | " self.transform = transform\n", 37 | "\n", 38 | " def __len__(self):\n", 39 | " return len(self.original)\n", 40 | "\n", 41 | " def __getitem__(self, item):\n", 42 | " image, digit = self.original[item]\n", 43 | " sample = {'original': image, 'digit': digit}\n", 44 | " if self.transform:\n", 45 | " sample = self.transform(sample)\n", 46 | "\n", 47 | " return sample\n", 48 | "\n", 49 | "\n", 50 | "class ToTensor:\n", 51 | " def __call__(self, sample):\n", 52 | " sample['original'] = functional.to_tensor(sample['original'])\n", 53 | " sample['digit'] = torch.as_tensor(np.asarray(sample['digit']),\n", 54 | " dtype=torch.int64)\n", 55 | " return sample\n", 56 | "\n", 57 | "\n", 58 | "class MaskImages:\n", 59 | " \"\"\"This torchvision image transformation prepares the MNIST digits to be\n", 60 | " used in the tutorial. Depending on the number of quadrants to be used as\n", 61 | " inputs (1, 2, or 3), the transformation masks the remaining (3, 2, 1)\n", 62 | " quadrant(s) setting their pixels with -1. Additionally, the transformation\n", 63 | " adds the target output in the sample dict as the complementary of the input\n", 64 | " \"\"\"\n", 65 | " def __init__(self, num_quadrant_inputs, mask_with=-1):\n", 66 | " if num_quadrant_inputs <= 0 or num_quadrant_inputs >= 4:\n", 67 | " raise ValueError('Number of quadrants as inputs must be 1, 2 or 3')\n", 68 | " self.num = num_quadrant_inputs\n", 69 | " self.mask_with = mask_with\n", 70 | "\n", 71 | " def __call__(self, sample):\n", 72 | " tensor = sample['original'].squeeze()\n", 73 | " out = tensor.detach().clone()\n", 74 | " h, w = tensor.shape\n", 75 | "\n", 76 | " # removes the bottom left quadrant from the target output\n", 77 | " out[h // 2:, :w // 2] = self.mask_with\n", 78 | " # if num of quadrants to be used as input is 2,\n", 79 | " # also removes the top left quadrant from the target output\n", 80 | " if self.num == 2:\n", 81 | " out[:, :w // 2] = self.mask_with\n", 82 | " # if num of quadrants to be used as input is 3,\n", 83 | " # also removes the top right quadrant from the target output\n", 84 | " if self.num == 3:\n", 85 | " out[:h // 2, :] = self.mask_with\n", 86 | "\n", 87 | " # now, sets the input as complementary\n", 88 | " inp = tensor.clone()\n", 89 | " inp[out != -1] = self.mask_with\n", 90 | "\n", 91 | " sample['input'] = inp\n", 92 | " sample['output'] = out\n", 93 | " return sample\n", 94 | "\n", 95 | "\n", 96 | "def get_data(num_quadrant_inputs, batch_size):\n", 97 | " transforms = Compose([\n", 98 | " ToTensor(),\n", 99 | " MaskImages(num_quadrant_inputs=num_quadrant_inputs)\n", 100 | " ])\n", 101 | " datasets, dataloaders, dataset_sizes = {}, {}, {}\n", 102 | " for mode in ['train', 'val']:\n", 103 | " datasets[mode] = CVAEMNIST(\n", 104 | " '../data',\n", 105 | " download=True,\n", 106 | " transform=transforms,\n", 107 | " train=mode == 'train'\n", 108 | " )\n", 109 | " dataloaders[mode] = DataLoader(\n", 110 | " datasets[mode],\n", 111 | " batch_size=batch_size,\n", 112 | " shuffle=mode == 'train',\n", 113 | " num_workers=0\n", 114 | " )\n", 115 | " 
dataset_sizes[mode] = len(datasets[mode])\n", 116 | "\n", 117 | " return datasets, dataloaders, dataset_sizes" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## Baseline: Deterministic Neural Network\n", 125 | "Before we dive into the CVAE implementation, let's code the baseline model. It is a straightforward implementation:" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "class BaselineNet(nn.Module):\n", 135 | " def __init__(self, hidden_1, hidden_2):\n", 136 | " super().__init__()\n", 137 | " self.fc1 = nn.Linear(784, hidden_1)\n", 138 | " self.fc2 = nn.Linear(hidden_1, hidden_2)\n", 139 | " self.fc3 = nn.Linear(hidden_2, 784)\n", 140 | " self.relu = nn.ReLU()\n", 141 | "\n", 142 | " def forward(self, x):\n", 143 | " x = x.view(-1, 784)\n", 144 | " hidden = self.relu(self.fc1(x))\n", 145 | " hidden = self.relu(self.fc2(hidden))\n", 146 | " y = torch.sigmoid(self.fc3(hidden))\n", 147 | " return y" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "In the paper, the authors compare the baseline NN with the proposed CVAE by comparing the negative (Conditional) Log Likelihood (CLL), averaged by image in the validation set. Thanks to PyTorch, computing the CLL is equivalent to computing the Binary Cross Entropy Loss using as input a signal passed through a Sigmoid layer. The code below does a small adjustment to leverage this: it only computes the loss in the pixels not masked with -1:" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "class MaskedBCELoss(nn.Module):\n", 164 | " def __init__(self, masked_with=-1):\n", 165 | " super().__init__()\n", 166 | " self.masked_with = masked_with\n", 167 | "\n", 168 | " def forward(self, input, target):\n", 169 | " target = target.view(input.shape)\n", 170 | " loss = F.binary_cross_entropy(input, target, reduction='none')\n", 171 | " loss[target == self.masked_with] = 0\n", 172 | " return loss.sum()" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "The training is very straightforward. We use 500 neurons in each hidden layer, Adam optimizer with `1e-3` learning rate, and early stopping. Please check the [Github repo](https://github.com/pyro-ppl/pyro/blob/dev/examples/cvae) for the full implementation.\n", 180 | "\n", 181 | "## Deep Conditional Generative Models for Structured Output Prediction\n", 182 | "As illustrated in the image below, there are three types of variables in a deep conditional generative model (CGM): input variables $\\bf x$, output variables $\\bf y$, and latent variables $\\bf z$. The conditional generative process of the model is given in (b) as follows: for given observation $\\bf x$, $\\bf z$ is drawn from the prior distribution $p_{\\theta}({\\bf z} | {\\bf x})$, and the output $\\bf y$ is generated from the distribution $p_{\\theta}({\\bf y} | {\\bf x, z})$. Compared to the baseline NN (a), the latent variables $\\bf z$ allow for modeling multiple modes in conditional distribution of output variables $\\bf y$ given input $\\bf x$, making the proposed CGM suitable for modeling one-to-many mapping.\n", 183 | "\n", 184 | "\n", 185 | "\"image1\"\n", 186 | "\n", 187 | "Deep CGMs are trained to maximize the conditional marginal log-likelihood. 
Often the objective function is intractable, and we apply the SGVB framework to train the model. The empirical lower bound is written as:\n",
188 |     "\n",
189 |     "$$ \tilde{\mathcal{L}}_{\text{CVAE}}(x, y; \theta, \phi) = -KL(q_{\phi}(z | x, y) || p_{\theta}(z | x)) + \frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}(y | x, z^{(l)}) $$\n",
190 |     "\n",
191 |     "where $\bf z^{(l)}$ is a Gaussian latent variable, and $L$ is the number of samples (or particles in Pyro nomenclature).\n",
192 |     "We call this model the conditional variational auto-encoder (CVAE). The CVAE is composed of multiple MLPs, such as the **recognition network** $q_{\phi}({\bf z} | \bf{x, y})$, the **(conditional) prior network** $p_{\theta}(\bf{z} | \bf{x})$, and the **generation network** $p_{\theta}(\bf{y} | \bf{x, z})$. In designing the network architecture, we build the network components of the CVAE **on top of the baseline NN**. Specifically, as shown in (d) above, not only the direct input $\bf x$, but also the initial guess $\hat{y}$ made by the NN are fed into the prior network. \n",
193 |     "\n",
194 |     "Pyro makes it really easy to translate this architecture into code. The recognition network and the (conditional) prior network are encoders from the traditional VAE setting, while the generation network is the decoder:"
195 |    ]
196 |   },
197 |   {
198 |    "cell_type": "code",
199 |    "execution_count": null,
200 |    "metadata": {},
201 |    "outputs": [],
202 |    "source": [
203 |     "class Encoder(nn.Module):\n",
204 |     "    def __init__(self, z_dim, hidden_1, hidden_2):\n",
205 |     "        super().__init__()\n",
206 |     "        self.fc1 = nn.Linear(784, hidden_1)\n",
207 |     "        self.fc2 = nn.Linear(hidden_1, hidden_2)\n",
208 |     "        self.fc31 = nn.Linear(hidden_2, z_dim)\n",
209 |     "        self.fc32 = nn.Linear(hidden_2, z_dim)\n",
210 |     "        self.relu = nn.ReLU()\n",
211 |     "\n",
212 |     "    def forward(self, x, y):\n",
213 |     "        # put x and y together in the same image for simplification\n",
214 |     "        xc = x.clone()\n",
215 |     "        xc[x == -1] = y[x == -1]\n",
216 |     "        xc = xc.view(-1, 784)\n",
217 |     "        # then compute the hidden units\n",
218 |     "        hidden = self.relu(self.fc1(xc))\n",
219 |     "        hidden = self.relu(self.fc2(hidden))\n",
220 |     "        # then return a mean vector and a (positive) square root covariance\n",
221 |     "        # each of size batch_size x z_dim\n",
222 |     "        z_loc = self.fc31(hidden)\n",
223 |     "        z_scale = torch.exp(self.fc32(hidden))\n",
224 |     "        return z_loc, z_scale\n",
225 |     "\n",
226 |     "\n",
227 |     "class Decoder(nn.Module):\n",
228 |     "    def __init__(self, z_dim, hidden_1, hidden_2):\n",
229 |     "        super().__init__()\n",
230 |     "        self.fc1 = nn.Linear(z_dim, hidden_1)\n",
231 |     "        self.fc2 = nn.Linear(hidden_1, hidden_2)\n",
232 |     "        self.fc3 = nn.Linear(hidden_2, 784)\n",
233 |     "        self.relu = nn.ReLU()\n",
234 |     "\n",
235 |     "    def forward(self, z):\n",
236 |     "        y = self.relu(self.fc1(z))\n",
237 |     "        y = self.relu(self.fc2(y))\n",
238 |     "        y = torch.sigmoid(self.fc3(y))\n",
239 |     "        return y\n",
240 |     "\n",
241 |     "\n",
242 |     "class CVAE(nn.Module):\n",
243 |     "    def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net):\n",
244 |     "        super().__init__()\n",
245 |     "        # The CVAE is composed of multiple MLPs, such as recognition network\n",
246 |     "        # qφ(z|x, y), (conditional) prior network pθ(z|x), and generation\n",
247 |     "        # network pθ(y|x, z). 
Also, CVAE is built on top of the NN: not only\n", 248 | " # the direct input x, but also the initial guess y_hat made by the NN\n", 249 | " # are fed into the prior network.\n", 250 | " self.baseline_net = pre_trained_baseline_net\n", 251 | " self.prior_net = Encoder(z_dim, hidden_1, hidden_2)\n", 252 | " self.generation_net = Decoder(z_dim, hidden_1, hidden_2)\n", 253 | " self.recognition_net = Encoder(z_dim, hidden_1, hidden_2)\n", 254 | "\n", 255 | " def model(self, xs, ys=None):\n", 256 | " # register this pytorch module and all of its sub-modules with pyro\n", 257 | " pyro.module(\"generation_net\", self)\n", 258 | " batch_size = xs.shape[0]\n", 259 | " with pyro.plate(\"data\"):\n", 260 | "\n", 261 | " # Prior network uses the baseline predictions as initial guess.\n", 262 | " # This is the generative process with recurrent connection\n", 263 | " with torch.no_grad():\n", 264 | " # this ensures the training process does not change the\n", 265 | " # baseline network\n", 266 | " y_hat = self.baseline_net(xs).view(xs.shape)\n", 267 | "\n", 268 | " # sample the handwriting style from the prior distribution, which is\n", 269 | " # modulated by the input xs.\n", 270 | " prior_loc, prior_scale = self.prior_net(xs, y_hat)\n", 271 | " zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))\n", 272 | "\n", 273 | " # the output y is generated from the distribution pθ(y|x, z)\n", 274 | " loc = self.generation_net(zs)\n", 275 | "\n", 276 | " if ys is not None:\n", 277 | " # In training, we will only sample in the masked image\n", 278 | " mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1)\n", 279 | " mask_ys = ys[xs == -1].view(batch_size, -1)\n", 280 | " pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)\n", 281 | " else:\n", 282 | " # In testing, no need to sample: the output is already a\n", 283 | " # probability in [0, 1] range, which better represent pixel\n", 284 | " # values considering grayscale. If we sample, we will force\n", 285 | " # each pixel to be either 0 or 1, killing the grayscale\n", 286 | " pyro.deterministic('y', loc.detach())\n", 287 | "\n", 288 | " # return the loc so we can visualize it later\n", 289 | " return loc\n", 290 | "\n", 291 | " def guide(self, xs, ys=None):\n", 292 | " with pyro.plate(\"data\"):\n", 293 | " if ys is None:\n", 294 | " # at inference time, ys is not provided. 
In that case,\n",
295 |     "                # the model uses the prior network\n",
296 |     "                y_hat = self.baseline_net(xs).view(xs.shape)\n",
297 |     "                loc, scale = self.prior_net(xs, y_hat)\n",
298 |     "            else:\n",
299 |     "                # at training time, it uses the variational distribution\n",
300 |     "                # q(z|x,y) = normal(loc(x,y),scale(x,y))\n",
301 |     "                loc, scale = self.recognition_net(xs, ys)\n",
302 |     "\n",
303 |     "            pyro.sample(\"z\", dist.Normal(loc, scale).to_event(1))\n",
304 |     "\n",
305 |     "    def save(self, model_path):\n",
306 |     "        torch.save({\n",
307 |     "            'prior': self.prior_net.state_dict(),\n",
308 |     "            'generation': self.generation_net.state_dict(),\n",
309 |     "            'recognition': self.recognition_net.state_dict()\n",
310 |     "        }, model_path)\n",
311 |     "\n",
312 |     "    def load(self, model_path, map_location=None):\n",
313 |     "        net_weights = torch.load(model_path, map_location=map_location)\n",
314 |     "        self.prior_net.load_state_dict(net_weights['prior'])\n",
315 |     "        self.generation_net.load_state_dict(net_weights['generation'])\n",
316 |     "        self.recognition_net.load_state_dict(net_weights['recognition'])\n",
317 |     "        self.prior_net.eval()\n",
318 |     "        self.generation_net.eval()\n",
319 |     "        self.recognition_net.eval()"
320 |    ]
321 |   },
322 |   {
323 |    "cell_type": "markdown",
324 |    "metadata": {},
325 |    "source": [
326 |     "## Evaluating the results\n",
327 |     "For qualitative analysis, we visualize the generated output samples in the next figure. As we can see, the baseline NN can only make a single deterministic prediction, and as a result the output looks blurry and doesn't look realistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they can even change their identity (digit labels), such as from 3 to 5 or from 4 to 9, and vice versa.\n",
328 |     "\n",
329 |     "\"image1\"\n",
330 |     "\n",
331 |     "We also provide quantitative evidence by estimating the marginal conditional log-likelihoods (CLLs) in the next table (lower is better).\n",
332 |     "\n",
333 |     "|                    | 1 quadrant | 2 quadrants | 3 quadrants |\n",
334 |     "|--------------------|------------|-------------|-------------|\n",
335 |     "| NN (baseline)      | 100.4      | 61.9        | 25.4        |\n",
336 |     "| CVAE (Monte Carlo) | 71.8       | 51.0        | 24.2        |\n",
337 |     "| Performance gap    | 28.6       | 10.9        | 1.2         |\n",
338 |     "\n",
339 |     "We achieved results similar to the ones reported by the authors in the paper. We trained for only 50 epochs with an early stopping patience of 3 epochs; to improve the results, we could let the algorithm train for longer. Nevertheless, we can observe the same effect shown in the paper: **the CVAE significantly outperforms the baseline NN in estimated CLL**.\n",
340 |     "\n",
341 |     "See the full code on [Github](https://github.com/pyro-ppl/pyro/blob/dev/examples/cvae).\n",
342 |     "\n",
343 |     "## References\n",
344 |     "\n",
345 |     "[1] `Learning Structured Output Representation using Deep Conditional Generative Models`,
    \n", 346 | "Kihyuk Sohn, Xinchen Yan, Honglak Lee" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [] 355 | } 356 | ], 357 | "metadata": { 358 | "kernelspec": { 359 | "display_name": "Python 3", 360 | "language": "python", 361 | "name": "python3" 362 | }, 363 | "language_info": { 364 | "codemirror_mode": { 365 | "name": "ipython", 366 | "version": 3 367 | }, 368 | "file_extension": ".py", 369 | "mimetype": "text/x-python", 370 | "name": "python", 371 | "nbconvert_exporter": "python", 372 | "pygments_lexer": "ipython3", 373 | "version": "3.7.6" 374 | } 375 | }, 376 | "nbformat": 4, 377 | "nbformat_minor": 4 378 | } 379 | -------------------------------------------------------------------------------- /src/cvae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pathlib import Path 4 | import pyro 5 | import pyro.distributions as dist 6 | from pyro.infer import SVI, Trace_ELBO, Predictive 7 | import torch 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | from mnist import get_val_images 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, z_dim, hidden_1, hidden_2): 15 | super().__init__() 16 | self.fc1 = nn.Linear(784, hidden_1) 17 | self.fc2 = nn.Linear(hidden_1, hidden_2) 18 | self.fc31 = nn.Linear(hidden_2, z_dim) 19 | self.fc32 = nn.Linear(hidden_2, z_dim) 20 | self.relu = nn.ReLU() 21 | 22 | def forward(self, x, y): 23 | # put x and y together in the same image for simplification 24 | xc = x.clone() 25 | xc[x == -1] = y[x == -1] 26 | xc = xc.view(-1, 784) 27 | # then compute the hidden units 28 | hidden = self.relu(self.fc1(xc)) 29 | hidden = self.relu(self.fc2(hidden)) 30 | # then return a mean vector and a (positive) square root covariance 31 | # each of size batch_size x z_dim 32 | z_loc = self.fc31(hidden) 33 | z_scale = torch.exp(self.fc32(hidden)) 34 | return z_loc, z_scale 35 | 36 | 37 | class Decoder(nn.Module): 38 | def __init__(self, z_dim, hidden_1, hidden_2): 39 | super().__init__() 40 | self.fc1 = nn.Linear(z_dim, hidden_1) 41 | self.fc2 = nn.Linear(hidden_1, hidden_2) 42 | self.fc3 = nn.Linear(hidden_2, 784) 43 | self.relu = nn.ReLU() 44 | 45 | def forward(self, z): 46 | y = self.relu(self.fc1(z)) 47 | y = self.relu(self.fc2(y)) 48 | y = torch.sigmoid(self.fc3(y)) 49 | return y 50 | 51 | 52 | class CVAE(nn.Module): 53 | def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net): 54 | super().__init__() 55 | # The CVAE is composed of multiple MLPs, such as recognition network 56 | # qφ(z|x, y), (conditional) prior network pθ(z|x), and generation 57 | # network pθ(y|x, z). Also, CVAE is built on top of the NN: not only 58 | # the direct input x, but also the initial guess y_hat made by the NN 59 | # are fed into the prior network. 60 | self.baseline_net = pre_trained_baseline_net 61 | self.prior_net = Encoder(z_dim, hidden_1, hidden_2) 62 | self.generation_net = Decoder(z_dim, hidden_1, hidden_2) 63 | self.recognition_net = Encoder(z_dim, hidden_1, hidden_2) 64 | 65 | def model(self, xs, ys=None): 66 | # register this pytorch module and all of its sub-modules with pyro 67 | pyro.module("generation_net", self) 68 | batch_size = xs.shape[0] 69 | with pyro.plate("data"): 70 | 71 | # Prior network uses the baseline predictions as initial guess. 
72 | # This is the generative process with recurrent connection 73 | with torch.no_grad(): 74 | # this ensures the training process does not change the 75 | # baseline network 76 | y_hat = self.baseline_net(xs).view(xs.shape) 77 | 78 | # sample the handwriting style from the prior distribution, which is 79 | # modulated by the input xs. 80 | prior_loc, prior_scale = self.prior_net(xs, y_hat) 81 | zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1)) 82 | 83 | # the output y is generated from the distribution pθ(y|x, z) 84 | loc = self.generation_net(zs) 85 | 86 | if ys is not None: 87 | # In training, we will only sample in the masked image 88 | mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1) 89 | mask_ys = ys[xs == -1].view(batch_size, -1) 90 | pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys) 91 | else: 92 | # In testing, no need to sample: the output is already a 93 | # probability in [0, 1] range, which better represent pixel 94 | # values considering grayscale. If we sample, we will force 95 | # each pixel to be either 0 or 1, killing the grayscale 96 | pyro.deterministic('y', loc.detach()) 97 | 98 | # return the loc so we can visualize it later 99 | return loc 100 | 101 | def guide(self, xs, ys=None): 102 | with pyro.plate("data"): 103 | if ys is None: 104 | # at inference time, ys is not provided. In that case, 105 | # the model uses the prior network 106 | y_hat = self.baseline_net(xs).view(xs.shape) 107 | loc, scale = self.prior_net(xs, y_hat) 108 | else: 109 | # at training time, uses the variational distribution 110 | # q(z|x,y) = normal(loc(x,y),scale(x,y)) 111 | loc, scale = self.recognition_net(xs, ys) 112 | 113 | pyro.sample("z", dist.Normal(loc, scale).to_event(1)) 114 | 115 | def save(self, model_path): 116 | torch.save({ 117 | 'prior': self.prior_net.state_dict(), 118 | 'generation': self.generation_net.state_dict(), 119 | 'recognition': self.recognition_net.state_dict() 120 | }, model_path) 121 | 122 | def load(self, model_path, map_location=None): 123 | net_weights = torch.load(model_path, map_location=map_location) 124 | self.prior_net.load_state_dict(net_weights['prior']) 125 | self.generation_net.load_state_dict(net_weights['generation']) 126 | self.recognition_net.load_state_dict(net_weights['recognition']) 127 | self.prior_net.eval() 128 | self.generation_net.eval() 129 | self.recognition_net.eval() 130 | 131 | 132 | def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs, 133 | early_stop_patience, model_path, pre_trained_baseline_net): 134 | 135 | # clear param store 136 | pyro.clear_param_store() 137 | 138 | cvae_net = CVAE(200, 500, 500, pre_trained_baseline_net) 139 | cvae_net.to(device) 140 | optimizer = pyro.optim.Adam({"lr": learning_rate}) 141 | svi = SVI(cvae_net.model, cvae_net.guide, optimizer, loss=Trace_ELBO()) 142 | 143 | best_loss = np.inf 144 | early_stop_count = 0 145 | Path(model_path).parent.mkdir(parents=True, exist_ok=True) 146 | 147 | # to track evolution 148 | val_inp, digits = get_val_images(num_quadrant_inputs=1, 149 | num_images=30, shuffle=False) 150 | val_inp = val_inp.to(device) 151 | samples = [] 152 | losses = [] 153 | 154 | for epoch in range(num_epochs): 155 | # Each epoch has a training and validation phase 156 | for phase in ['train', 'val']: 157 | running_loss = 0.0 158 | 159 | # Iterate over data. 
160 | bar = tqdm(dataloaders[phase], 161 | desc='CVAE Epoch {} {}'.format(epoch, phase).ljust(20)) 162 | for i, batch in enumerate(bar): 163 | inputs = batch['input'].to(device) 164 | outputs = batch['output'].to(device) 165 | 166 | if phase == 'train': 167 | loss = svi.step(inputs, outputs) / inputs.size(0) 168 | else: 169 | loss = svi.evaluate_loss(inputs, outputs) / inputs.size(0) 170 | 171 | # statistics 172 | running_loss += loss 173 | if i % 10 == 0: 174 | bar.set_postfix(loss='{:.2f}'.format(loss), 175 | early_stop_count=early_stop_count) 176 | 177 | # track evolution 178 | if phase == 'train': 179 | df = pd.DataFrame(columns=['epoch', 'loss']) 180 | df.loc[0] = [epoch + float(i) / len(dataloaders[phase]), loss] 181 | losses.append(df) 182 | if i % 47 == 0: # every 10% of training (469) 183 | dfs = predict_samples( 184 | val_inp, digits, cvae_net, 185 | epoch + float(i) / len(dataloaders[phase]), 186 | ) 187 | samples.append(dfs) 188 | 189 | epoch_loss = running_loss / dataset_sizes[phase] 190 | # deep copy the model 191 | if phase == 'val': 192 | if epoch_loss < best_loss: 193 | best_loss = epoch_loss 194 | cvae_net.save(model_path) 195 | early_stop_count = 0 196 | else: 197 | early_stop_count += 1 198 | 199 | if early_stop_count >= early_stop_patience: 200 | break 201 | 202 | # Save model weights 203 | cvae_net.load(model_path) 204 | 205 | # record evolution 206 | samples = pd.concat(samples, axis=0, ignore_index=True) 207 | samples.to_csv('samples.csv', index=False) 208 | 209 | losses = pd.concat(losses, axis=0, ignore_index=True) 210 | losses.to_csv('losses.csv', index=False) 211 | 212 | return cvae_net 213 | 214 | 215 | def predict_samples(inputs, digits, pre_trained_cvae, epoch_frac): 216 | predictive = Predictive(pre_trained_cvae.model, 217 | guide=pre_trained_cvae.guide, 218 | num_samples=1) 219 | preds = predictive(inputs) 220 | y_loc = preds['y'].squeeze().detach().cpu().numpy() 221 | dfs = pd.DataFrame(data=y_loc) 222 | dfs['digit'] = digits.numpy() 223 | dfs['epoch'] = epoch_frac 224 | return dfs 225 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import pyro 4 | import torch 5 | import baseline 6 | import cvae 7 | from util import get_data, visualize, generate_table 8 | 9 | 10 | def main(args): 11 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda 12 | else "cpu") 13 | results = [] 14 | columns = [] 15 | 16 | for num_quadrant_inputs in args.num_quadrant_inputs: 17 | # adds an s in case of plural quadrants 18 | maybes = 's' if num_quadrant_inputs > 1 else '' 19 | 20 | print('Training with {} quadrant{} as input...' 
21 | .format(num_quadrant_inputs, maybes)) 22 | 23 | # Dataset 24 | datasets, dataloaders, dataset_sizes = get_data( 25 | num_quadrant_inputs=num_quadrant_inputs, 26 | batch_size=128 27 | ) 28 | 29 | # Train baseline 30 | baseline_net = baseline.train( 31 | device=device, 32 | dataloaders=dataloaders, 33 | dataset_sizes=dataset_sizes, 34 | learning_rate=args.learning_rate, 35 | num_epochs=args.num_epochs, 36 | early_stop_patience=args.early_stop_patience, 37 | model_path='baseline_net_q{}.pth'.format(num_quadrant_inputs) 38 | ) 39 | 40 | # Train CVAE 41 | cvae_net = cvae.train( 42 | device=device, 43 | dataloaders=dataloaders, 44 | dataset_sizes=dataset_sizes, 45 | learning_rate=args.learning_rate, 46 | num_epochs=args.num_epochs, 47 | early_stop_patience=args.early_stop_patience, 48 | model_path='cvae_net_q{}.pth'.format(num_quadrant_inputs), 49 | pre_trained_baseline_net=baseline_net 50 | ) 51 | 52 | # Visualize conditional predictions 53 | visualize( 54 | device=device, 55 | num_quadrant_inputs=num_quadrant_inputs, 56 | pre_trained_baseline=baseline_net, 57 | pre_trained_cvae=cvae_net, 58 | num_images=args.num_images, 59 | num_samples=args.num_samples, 60 | image_path='cvae_plot_q{}.png'.format(num_quadrant_inputs) 61 | ) 62 | 63 | # Retrieve conditional log likelihood 64 | df = generate_table( 65 | device=device, 66 | num_quadrant_inputs=num_quadrant_inputs, 67 | pre_trained_baseline=baseline_net, 68 | pre_trained_cvae=cvae_net, 69 | num_particles=args.num_particles, 70 | col_name='{} quadrant{}'.format(num_quadrant_inputs, maybes) 71 | ) 72 | results.append(df) 73 | columns.append('{} quadrant{}'.format(num_quadrant_inputs, maybes)) 74 | 75 | results = pd.concat(results, axis=1, ignore_index=True) 76 | results.columns = columns 77 | results.loc['Performance gap', :] = results.iloc[0, :] - results.iloc[1, :] 78 | results.to_csv('results.csv') 79 | 80 | 81 | if __name__ == '__main__': 82 | assert pyro.__version__.startswith('1.4.0') 83 | # parse command line arguments 84 | parser = argparse.ArgumentParser(description="parse args") 85 | parser.add_argument('-nq', '--num-quadrant-inputs', metavar='N', type=int, 86 | nargs='+', default=[1, 2, 3], 87 | help='num of quadrants to use as inputs') 88 | parser.add_argument('-n', '--num-epochs', default=101, type=int, 89 | help='number of training epochs') 90 | parser.add_argument('-esp', '--early-stop-patience', default=3, type=int, 91 | help='early stop patience') 92 | parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, 93 | help='learning rate') 94 | parser.add_argument('--cuda', action='store_true', default=False, 95 | help='whether to use cuda') 96 | parser.add_argument('-vi', '--num-images', default=10, type=int, 97 | help='number of images to visualize') 98 | parser.add_argument('-vs', '--num-samples', default=10, type=int, 99 | help='number of samples to visualize per image') 100 | parser.add_argument('-p', '--num-particles', default=10, type=int, 101 | help='n of particles to estimate logpθ(y|x,z) in ELBO') 102 | args = parser.parse_args() 103 | 104 | main(args) 105 | -------------------------------------------------------------------------------- /src/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision.datasets import MNIST 5 | from torchvision.transforms import Compose, functional 6 | 7 | 8 | class CVAEMNIST(Dataset): 9 | def __init__(self, root, train=True, 
transform=None, download=False): 10 | self.original = MNIST(root, train=train, download=download) 11 | self.transform = transform 12 | 13 | def __len__(self): 14 | return len(self.original) 15 | 16 | def __getitem__(self, item): 17 | image, digit = self.original[item] 18 | sample = {'original': image, 'digit': digit} 19 | if self.transform: 20 | sample = self.transform(sample) 21 | 22 | return sample 23 | 24 | 25 | class ToTensor: 26 | def __call__(self, sample): 27 | sample['original'] = functional.to_tensor(sample['original']) 28 | sample['digit'] = torch.as_tensor(np.asarray(sample['digit']), 29 | dtype=torch.int64) 30 | return sample 31 | 32 | 33 | class MaskImages: 34 | """This torchvision image transformation prepares the MNIST digits to be 35 | used in the tutorial. Depending on the number of quadrants to be used as 36 | inputs (1, 2, or 3), the transformation masks the remaining (3, 2, 1) 37 | quadrant(s) setting their pixels with -1. Additionally, the transformation 38 | adds the target output in the sample dict as the complementary of the input 39 | """ 40 | def __init__(self, num_quadrant_inputs, mask_with=-1): 41 | if num_quadrant_inputs <= 0 or num_quadrant_inputs >= 4: 42 | raise ValueError('Number of quadrants as inputs must be 1, 2 or 3') 43 | self.num = num_quadrant_inputs 44 | self.mask_with = mask_with 45 | 46 | def __call__(self, sample): 47 | tensor = sample['original'].squeeze() 48 | out = tensor.detach().clone() 49 | h, w = tensor.shape 50 | 51 | # removes the bottom left quadrant from the target output 52 | out[h // 2:, :w // 2] = self.mask_with 53 | # if num of quadrants to be used as input is 2, 54 | # also removes the top left quadrant from the target output 55 | if self.num == 2: 56 | out[:, :w // 2] = self.mask_with 57 | # if num of quadrants to be used as input is 3, 58 | # also removes the top right quadrant from the target output 59 | if self.num == 3: 60 | out[:h // 2, :] = self.mask_with 61 | 62 | # now, sets the input as complementary 63 | inp = tensor.clone() 64 | inp[out != -1] = self.mask_with 65 | 66 | sample['input'] = inp 67 | sample['output'] = out 68 | return sample 69 | 70 | 71 | def get_data(num_quadrant_inputs, batch_size): 72 | transforms = Compose([ 73 | ToTensor(), 74 | MaskImages(num_quadrant_inputs=num_quadrant_inputs) 75 | ]) 76 | datasets, dataloaders, dataset_sizes = {}, {}, {} 77 | for mode in ['train', 'val']: 78 | datasets[mode] = CVAEMNIST( 79 | '../data', 80 | download=True, 81 | transform=transforms, 82 | train=mode == 'train' 83 | ) 84 | dataloaders[mode] = DataLoader( 85 | datasets[mode], 86 | batch_size=batch_size, 87 | shuffle=mode == 'train', 88 | num_workers=0 89 | ) 90 | dataset_sizes[mode] = len(datasets[mode]) 91 | 92 | return datasets, dataloaders, dataset_sizes 93 | 94 | 95 | def get_val_images(num_quadrant_inputs, num_images, shuffle): 96 | datasets, _, dataset_sizes = get_data( 97 | num_quadrant_inputs=num_quadrant_inputs, 98 | batch_size=num_images 99 | ) 100 | dataloader = DataLoader(datasets['val'], batch_size=num_images, 101 | shuffle=shuffle) 102 | 103 | batch = next(iter(dataloader)) 104 | inputs = batch['input'] 105 | digits = batch['digit'] 106 | return inputs, digits 107 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | from pathlib import Path 5 | from pyro.infer import Predictive, Trace_ELBO 6 | 
from sklearn.manifold import TSNE 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torchvision.utils import make_grid 10 | from tqdm import tqdm 11 | from baseline import MaskedBCELoss, BaselineNet 12 | from mnist import get_data, get_val_images 13 | from cvae import CVAE 14 | 15 | 16 | def imshow(inp, image_path=None): 17 | inp = inp.cpu().numpy().transpose((1, 2, 0)) 18 | space = np.ones((inp.shape[0], 50, inp.shape[2])) 19 | inp = np.concatenate([space, inp], axis=1) 20 | 21 | ax = plt.axes(frameon=False, xticks=[], yticks=[]) 22 | ax.text(0, 23, 'Inputs:') 23 | ax.text(0, 23 + 28 + 3, 'Truth:') 24 | ax.text(0, 23 + (28 + 3) * 2, 'NN:') 25 | ax.text(0, 23 + (28 + 3) * 3, 'CVAE:') 26 | ax.imshow(inp) 27 | 28 | if image_path is not None: 29 | Path(image_path).parent.mkdir(parents=True, exist_ok=True) 30 | plt.savefig(image_path, bbox_inches='tight', pad_inches=0.1) 31 | else: 32 | plt.show() 33 | 34 | plt.clf() 35 | 36 | 37 | def visualize(device, num_quadrant_inputs, pre_trained_baseline, 38 | pre_trained_cvae, num_images, num_samples, image_path=None): 39 | 40 | # Load sample random data 41 | datasets, _, dataset_sizes = get_data( 42 | num_quadrant_inputs=num_quadrant_inputs, 43 | batch_size=num_images 44 | ) 45 | dataloader = DataLoader(datasets['val'], batch_size=num_images, shuffle=True) 46 | 47 | batch = next(iter(dataloader)) 48 | inputs = batch['input'].to(device) 49 | outputs = batch['output'].to(device) 50 | originals = batch['original'].to(device) 51 | 52 | # Make predictions 53 | with torch.no_grad(): 54 | baseline_preds = pre_trained_baseline(inputs).view(outputs.shape) 55 | 56 | predictive = Predictive(pre_trained_cvae.model, 57 | guide=pre_trained_cvae.guide, 58 | num_samples=num_samples) 59 | cvae_preds = predictive(inputs)['y'].view(num_samples, num_images, 28, 28) 60 | 61 | # Predictions are only made in the pixels not masked. 
This completes 62 | # the input quadrant with the prediction for the missing quadrants, for 63 | # visualization purpose 64 | baseline_preds[outputs == -1] = inputs[outputs == -1] 65 | for i in range(cvae_preds.shape[0]): 66 | cvae_preds[i][outputs == -1] = inputs[outputs == -1] 67 | 68 | # adjust tensor sizes 69 | inputs = inputs.unsqueeze(1) 70 | inputs[inputs == -1] = 1 71 | baseline_preds = baseline_preds.unsqueeze(1) 72 | cvae_preds = cvae_preds.view(-1, 28, 28).unsqueeze(1) 73 | 74 | # make grids 75 | inputs_tensor = make_grid(inputs, nrow=num_images, padding=0) 76 | originals_tensor = make_grid(originals, nrow=num_images, padding=0) 77 | separator_tensor = torch.ones((3, 5, originals_tensor.shape[-1])).to(device) 78 | baseline_tensor = make_grid(baseline_preds, nrow=num_images, padding=0) 79 | cvae_tensor = make_grid(cvae_preds, nrow=num_images, padding=0) 80 | 81 | # add vertical and horizontal lines 82 | for tensor in [originals_tensor, baseline_tensor, cvae_tensor]: 83 | for i in range(num_images - 1): 84 | tensor[:, :, (i + 1) * 28] = 0.3 85 | 86 | for i in range(num_samples - 1): 87 | cvae_tensor[:, (i + 1) * 28, :] = 0.3 88 | 89 | # concatenate all tensors 90 | grid_tensor = torch.cat([inputs_tensor, separator_tensor, originals_tensor, 91 | separator_tensor, baseline_tensor, 92 | separator_tensor, cvae_tensor], dim=1) 93 | # plot tensors 94 | imshow(grid_tensor, image_path=image_path) 95 | 96 | 97 | def generate_table(device, num_quadrant_inputs, pre_trained_baseline, 98 | pre_trained_cvae, num_particles, col_name): 99 | 100 | # Load sample random data 101 | datasets, dataloaders, dataset_sizes = get_data( 102 | num_quadrant_inputs=num_quadrant_inputs, 103 | batch_size=32 104 | ) 105 | 106 | # Load sample data 107 | criterion = MaskedBCELoss() 108 | loss_fn = Trace_ELBO(num_particles=num_particles).differentiable_loss 109 | 110 | baseline_cll = 0.0 111 | cvae_mc_cll = 0.0 112 | num_preds = 0 113 | 114 | df = pd.DataFrame(index=['NN (baseline)', 'CVAE (Monte Carlo)'], 115 | columns=[col_name]) 116 | 117 | # Iterate over data. 
118 |     bar = tqdm(dataloaders['val'], desc='Generating predictions'.ljust(20))
119 |     for batch in bar:
120 |         inputs = batch['input'].to(device)
121 |         outputs = batch['output'].to(device)
122 |         num_preds += 1
123 | 
124 |         # Compute negative log likelihood for the baseline NN
125 |         with torch.no_grad():
126 |             preds = pre_trained_baseline(inputs)
127 |             baseline_cll += criterion(preds, outputs).item() / inputs.size(0)
128 | 
129 |         # Compute the negative conditional log likelihood for the CVAE
130 |         cvae_mc_cll += loss_fn(pre_trained_cvae.model,
131 |                                pre_trained_cvae.guide,
132 |                                inputs, outputs).detach().item() / inputs.size(0)
133 | 
134 |     df.iloc[0, 0] = baseline_cll / num_preds
135 |     df.iloc[1, 0] = cvae_mc_cll / num_preds
136 |     return df
137 | 
138 | 
139 | if __name__ == '__main__':
140 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
141 | 
142 |     # Dataset
143 |     datasets, dataloaders, dataset_sizes = get_data(
144 |         num_quadrant_inputs=1,
145 |         batch_size=128
146 |     )
147 |     baseline_net = BaselineNet(500, 500)
148 |     baseline_net.load_state_dict(
149 |         torch.load('/Users/carlossouza/Downloads/baseline_net_q1.pth',
150 |                    map_location='cpu'))
151 |     baseline_net.eval()
152 | 
153 |     cvae_net = CVAE(200, 500, 500, baseline_net)
154 |     # CVAE.save stores the three sub-network state dicts in a single file,
155 |     # so it must be restored with the custom CVAE.load method
156 |     cvae_net.load('/Users/carlossouza/Downloads/cvae_net_q1.pth',
157 |                   map_location='cpu')
158 | 
159 |     visualize(
160 |         device=device,
161 |         num_quadrant_inputs=1,
162 |         pre_trained_baseline=baseline_net,
163 |         pre_trained_cvae=cvae_net,
164 |         num_images=10,
165 |         num_samples=10
166 |     )
167 | 
168 |     # df = generate_table(
169 |     #     device=device,
170 |     #     num_quadrant_inputs=1,
171 |     #     pre_trained_baseline=baseline_net,
172 |     #     pre_trained_cvae=cvae_net,
173 |     #     num_particles=10,
174 |     #     col_name='{} quadrant'.format(1)
175 |     # )
176 |     # print(df)
177 | 
--------------------------------------------------------------------------------
/src/video.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | from sys import platform
4 | if platform == 'linux':
5 |     import matplotlib
6 |     matplotlib.use('Agg')
7 | 
8 | import matplotlib.pyplot as plt
9 | import matplotlib.animation as animation
10 | import matplotlib.gridspec as gridspec
11 | import numpy as np
12 | from mnist import get_val_images
13 | import cv2
14 | from PIL import Image
15 | 
16 | 
17 | fig = plt.figure(figsize=(8, 7))
18 | spec = gridspec.GridSpec(nrows=2, ncols=3, hspace=.3, wspace=.25)
19 | axl = fig.add_subplot(spec[0, :])
20 | axd = [fig.add_subplot(spec[1, i]) for i in range(3)]
21 | plt.subplots_adjust(left=0.10, bottom=0.05, right=0.90, top=0.90,
22 |                     wspace=0.3, hspace=0.25)
23 | im = Image.open('../data/pyro_logo.png')
24 | im.thumbnail((900, 900), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
25 | 
26 | im2 = im.copy()
27 | im2.putalpha(30)
28 | im.paste(im2, im)
29 | 
30 | im = np.array(im).astype(float) / 255  # np.float was removed in NumPy 1.24
31 | fig.figimage(im, 600, 850)  # 1200, 800
32 | 
33 | 
34 | def animate(i, dfs, dfl, inputs, digits):
35 |     if i < len(dfl):
36 |         axl.clear()
37 |         axl.set_ylim(top=200, bottom=60)
38 |         axl.plot(dfl.iloc[0:i, 0].values, dfl.iloc[0:i, 1].values)
39 |         axl.set_ylabel('Loss')
40 |         axl.set_xlabel('Epochs')
41 |         axl.set_title('Training Progress')
42 | 
43 |         s = f'Loss: {dfl.iloc[i - 10:i, 1].mean():.2f}'
44 |         axl.text(0.85, 0.96, s, horizontalalignment='left',
45 |                  verticalalignment='top', transform=axl.transAxes)
46 |         s = f'Epoch: {dfl.iloc[i, 0]:.1f}'
47 |         axl.text(0.85, 0.88, s, horizontalalignment='left',
48 |                  verticalalignment='top', transform=axl.transAxes)
49 | 
50 |         data = dfs[dfs['epoch'] == dfl['epoch'].iloc[i]]
51 |         if len(data) > 0:
52 |             for j, k in enumerate([0, 18, 4]):  # index 15 also good 3-5 confusion
53 |                 img = data.iloc[k, :784].values.reshape(28, 28)
54 |                 inp = inputs[k]
55 |                 img[inp != -1] = inp[inp != -1]
56 | 
57 |                 img = cv2.resize(img, dsize=(280, 280),
58 |                                  interpolation=cv2.INTER_NEAREST)
59 |                 img = np.stack((img,)*3, axis=-1)
60 |                 img[140:, 140, 1] = 1
61 |                 img[140, :140, 1] = 1
62 | 
63 |                 axd[j].clear()
64 |                 axd[j].imshow(img, cmap='gray')
65 |                 axd[j].set_title('Sample %d' % (digits[k]))
66 |                 axd[j].get_xaxis().set_visible(False)
67 |                 axd[j].get_yaxis().set_visible(False)
68 | 
69 | 
70 | def main(args):
71 |     dfs = pd.read_csv('../data/samples.csv')
72 |     dfl = pd.read_csv('../data/losses.csv')
73 | 
74 |     inputs, digits = get_val_images(num_quadrant_inputs=1,
75 |                                     num_images=30, shuffle=False)
76 |     inputs = inputs.numpy()
77 | 
78 |     Writer = animation.writers['ffmpeg']
79 |     writer = Writer(fps=15, metadata=dict(artist='Carlos Souza'), bitrate=1800)
80 | 
81 |     ani = animation.FuncAnimation(fig, animate, interval=50, frames=100000,
82 |                                   fargs=(dfs, dfl, inputs, digits, ))
83 | 
84 |     if args.show:
85 |         plt.show()
86 |     else:
87 |         ani.save('animation.mp4', writer=writer)
88 | 
89 | 
90 | if __name__ == '__main__':
91 |     parser = argparse.ArgumentParser(description='Generate animation.')
92 |     parser.add_argument('-s', '--show', action="store_true", default=False,
93 |                         help='Use this flag to show the video animation on '
94 |                              'screen instead of saving it to file. Default '
95 |                              'is to save to "animation.mp4".')
96 |     args = parser.parse_args()
97 |     main(args)
98 | 
--------------------------------------------------------------------------------