├── .gitignore
├── README.md
├── requirements.txt
└── src
    ├── baseline.py
    ├── cvae.ipynb
    ├── cvae.py
    ├── main.py
    ├── mnist.py
    ├── util.py
    └── video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | src/.ipynb_checkpoints
3 | src/__pycache__
4 | data/
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Variational Auto-encoder
2 | [Papers with Code: Structured Prediction on MNIST](https://paperswithcode.com/sota/structured-prediction-on-mnist?p=learning-structured-output-representation)
3 |
4 | ## Introduction
5 |
6 | This tutorial implements [Learning Structured Output Representation using Deep Conditional Generative Models](http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generati) paper, which introduced Conditional Variational Auto-encoders in 2015, using Pyro PPL.
7 |
8 | Supervised deep learning has been successfully applied to many recognition problems in machine learning and computer vision.
9 | Although it can approximate a complex many-to-one function very well when a large amount of training data is provided, the lack of probabilistic inference in current supervised deep learning methods makes it difficult to model complex structured output representations.
10 | In this work, Kihyuk Sohn, Honglak Lee and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables.
11 | The model is trained efficiently in the framework of stochastic gradient variational Bayes and allows fast prediction using stochastic feed-forward inference.
12 | They called the model Conditional Variational Auto-encoder (CVAE).
13 |
14 | The CVAE is a conditional directed graphical model whose input observations modulate the prior on Gaussian latent variables that generate the outputs.
15 | It is trained to maximize the conditional marginal log-likelihood.
16 | The authors formulate the variational learning objective of the CVAE in the framework of stochastic gradient variational Bayes (SGVB).
17 | In experiments, they demonstrate the effectiveness of the CVAE in comparison to the deterministic neural network counterparts in generating diverse but realistic output predictions using stochastic inference.
18 | Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using the MNIST database.
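For reference, the SGVB objective that the CVAE maximizes for each input-output pair is the empirical lower bound from the paper (the same expression given in `src/cvae.ipynb`). Here L is the number of latent samples, and each z^(l) is drawn from the recognition network q_φ(z | x, y):

$$ \tilde{\mathcal{L}}_{\text{CVAE}}(x, y; \theta, \phi) = -KL\left(q_{\phi}(z \mid x, y) \,\|\, p_{\theta}(z \mid x)\right) + \frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}(y \mid x, z^{(l)}), \qquad z^{(l)} \sim q_{\phi}(z \mid x, y) $$

In `src/cvae.py`, Pyro's `Trace_ELBO` estimates this bound by sampling z from `guide` (the recognition network) and scoring it under `model` (the prior and generation networks). The KL term is therefore also estimated by sampling rather than computed in closed form. With the default `Trace_ELBO()`, a single sample (L = 1) is drawn per step.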
19 |
20 | ## The problem
21 | Let's divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as an output to be predicted.
22 | The image below shows the case where one quadrant is the input:
23 |
24 |
25 |
26 | Our objective is to **learn a model that can perform probabilistic inference and make diverse predictions from a single input**.
27 | This is because we are not simply modeling a many-to-one function as in classification tasks; we may need to model a mapping from a single input to many possible outputs. One of the limitations of deterministic neural networks is that they generate only a single prediction.
28 | In the example above, the input shows a small part of a digit that might be a three or a five.
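The quadrant split can be sketched in NumPy as follows. This minimal sketch assumes the same convention as the `MaskImages` transform in `src/cvae.ipynb`:

- pixels that do not belong to an image are set to -1;
- the bottom-left quadrant is always an input;
- the top-left quadrant is added for two inputs, and the top-right for three.

The function name `mask_quadrants` is hypothetical and is not part of the repository.

```python
import numpy as np


def mask_quadrants(image, num_inputs, mask_with=-1.0):
    """Hypothetical helper: split a 2-D image into (input, output) images.

    Input quadrants: 1 -> bottom-left; 2 -> + top-left; 3 -> + top-right.
    Pixels outside each returned image are set to `mask_with`.
    """
    if num_inputs not in (1, 2, 3):
        raise ValueError("Number of quadrants as inputs must be 1, 2 or 3")
    h, w = image.shape
    # boolean map of the pixels that belong to the input quadrants
    is_input = np.zeros((h, w), dtype=bool)
    is_input[h // 2:, :w // 2] = True          # bottom-left
    if num_inputs >= 2:
        is_input[:h // 2, :w // 2] = True      # top-left
    if num_inputs == 3:
        is_input[:h // 2, w // 2:] = True      # top-right
    image = image.astype(float)
    inp = np.where(is_input, image, mask_with)  # keep only the input pixels
    out = np.where(is_input, mask_with, image)  # keep only the output pixels
    return inp, out


# On a toy 4x4 image with one input quadrant, the input keeps the 4
# bottom-left pixels and the output keeps the remaining 12.
inp, out = mask_quadrants(np.ones((4, 4)), num_inputs=1)
print((inp != -1).sum(), (out != -1).sum())
```

Building an explicit boolean map avoids the complement step depending on pixel values, so the input and output are always exact complements of each other.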
29 |
30 | ## Evaluating the results
31 | For qualitative analysis, we visualize the generated output samples in the next figure. The baseline NNs can only make a single deterministic prediction; as a result, their output is blurry and looks unrealistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they even change their identity (digit label), for example from 3 to 5 or from 4 to 9, and vice versa.
32 |
33 |
34 |
35 | We also provide quantitative evidence in the next table, which reports the estimated negative marginal conditional log-likelihoods (CLLs) on the validation set; lower is better.
36 |
37 | | | 1 quadrant | 2 quadrants | 3 quadrants |
38 | |--------------------|------------|-------------|-------------|
39 | | NN (baseline) | 100.4 | 61.9 | 25.4 |
40 | | CVAE (Monte Carlo) | 71.8 | 51.0 | 24.2 |
41 | | Performance gap | 28.6 | 10.9 | 1.2 |
42 |
43 | Our results are similar to those reported by the authors in the paper. We trained for only 50 epochs with an early stopping patience of 3 epochs; training for longer could improve the results. Nevertheless, we observe the same effect shown in the paper: **the CVAE achieves lower estimated negative CLLs than the baseline NN in every setting**, with the largest gap when only one quadrant is given as input.
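The "CVAE (Monte Carlo)" row refers to a Monte Carlo estimate of the CLL, which has no closed form for a CVAE. Below is a minimal NumPy sketch of one such estimator. It assumes S decoder outputs, one for each latent sample z_s ~ p(z | x) drawn from the prior network, and the -1 masking convention of `MaskedBCELoss`. It illustrates the technique and is not the evaluation code used in this repository. The helper names are hypothetical.

```python
import numpy as np


def masked_bernoulli_log_lik(probs, target, masked_with=-1.0, eps=1e-7):
    """Hypothetical helper: per-sample log-likelihood over unmasked pixels.

    probs: (S, D) predicted pixel probabilities, one row per latent sample.
    target: (D,) target pixels in [0, 1], or `masked_with` where masked.
    Uses the binary cross-entropy form, so grayscale targets are allowed.
    Returns an (S,) array of log-likelihoods.
    """
    keep = target != masked_with
    p = np.clip(probs[:, keep], eps, 1 - eps)  # avoid log(0)
    y = target[keep]
    return (y * np.log(p) + (1 - y) * np.log(1 - p)).sum(axis=1)


def mc_negative_cll(probs, target, masked_with=-1.0):
    """Hypothetical helper: -log (1/S) sum_s p(y | x, z_s).

    Row s of `probs` is the decoder output for z_s ~ p(z | x).
    """
    log_liks = masked_bernoulli_log_lik(probs, target, masked_with)
    # log-mean-exp, shifted by the maximum for numerical stability
    m = log_liks.max()
    log_mean = m + np.log(np.mean(np.exp(log_liks - m)))
    return -log_mean


# Two latent samples for a 4-pixel target whose third pixel is masked.
target = np.array([1.0, 0.0, -1.0, 1.0])
probs = np.array([[0.9, 0.2, 0.5, 0.8],
                  [0.6, 0.4, 0.5, 0.7]])
print(mc_negative_cll(probs, target))
```

The estimator averages likelihoods (through log-mean-exp) rather than losses, which is what makes it an estimate of the marginal log-likelihood. By Jensen's inequality, the result is never larger than the average per-sample loss.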
44 |
45 | See the full code on [Github](https://github.com/ucals/cvae).
46 |
47 | ## IMPORTANT
48 | Some users have reported issues when running the code with Pyro versions other than the one pinned in `requirements.txt`.
49 | To make sure the code works, create a clean virtual environment (conda or virtualenv) and run `pip install -r requirements.txt` in that environment.
50 |
51 | ## References
52 |
53 | [1] `Learning Structured Output Representation using Deep Conditional Generative Models`,
54 | Kihyuk Sohn, Xinchen Yan, Honglak Lee
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyro-ppl==1.4.0
2 | pandas==1.5.3
3 | torch==1.13.1
4 | torchvision==0.14.1
5 | matplotlib==3.7.1
6 | scikit-learn==1.2.2
7 |
--------------------------------------------------------------------------------
/src/baseline.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | from pathlib import Path
4 | from tqdm import tqdm
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BaselineNet(nn.Module):
11 |     def __init__(self, hidden_1, hidden_2):
12 |         super().__init__()
13 |         self.fc1 = nn.Linear(784, hidden_1)
14 |         self.fc2 = nn.Linear(hidden_1, hidden_2)
15 |         self.fc3 = nn.Linear(hidden_2, 784)
16 |         self.relu = nn.ReLU()
17 |
18 |     def forward(self, x):
19 |         x = x.view(-1, 784)
20 |         hidden = self.relu(self.fc1(x))
21 |         hidden = self.relu(self.fc2(hidden))
22 |         y = torch.sigmoid(self.fc3(hidden))
23 |         return y
24 |
25 |
26 | class MaskedBCELoss(nn.Module):
27 |     def __init__(self, masked_with=-1):
28 |         super().__init__()
29 |         self.masked_with = masked_with
30 |
31 |     def forward(self, input, target):
32 |         target = target.view(input.shape)
33 |         loss = F.binary_cross_entropy(input, target, reduction='none')
34 |         loss[target == self.masked_with] = 0
35 |         return loss.sum()
36 |
37 |
38 | def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs,
39 |           early_stop_patience, model_path):
40 |
41 |     # Train baseline
42 |     baseline_net = BaselineNet(500, 500)
43 |     baseline_net.to(device)
44 |     optimizer = torch.optim.Adam(baseline_net.parameters(), lr=learning_rate)
45 |     criterion = MaskedBCELoss()
46 |     best_loss = np.inf
47 |     # keep a copy of the initial weights so best_model_wts is always defined
48 |     best_model_wts = copy.deepcopy(baseline_net.state_dict())
49 |     early_stop_count = 0
50 |
51 |     for epoch in range(num_epochs):
52 |         for phase in ['train', 'val']:
53 |             if phase == 'train':
54 |                 baseline_net.train()
55 |             else:
56 |                 baseline_net.eval()
57 |
58 |             running_loss = 0.0
59 |             num_preds = 0
60 |
61 |             bar = tqdm(dataloaders[phase],
62 |                        desc='NN Epoch {} {}'.format(epoch, phase).ljust(20))
63 |             for i, batch in enumerate(bar):
64 |                 inputs = batch['input'].to(device)
65 |                 outputs = batch['output'].to(device)
66 |
67 |                 optimizer.zero_grad()
68 |
69 |                 with torch.set_grad_enabled(phase == 'train'):
70 |                     preds = baseline_net(inputs)
71 |                     loss = criterion(preds, outputs) / inputs.size(0)
72 |                     if phase == 'train':
73 |                         loss.backward()
74 |                         optimizer.step()
75 |
76 |                 running_loss += loss.item()
77 |                 num_preds += 1
78 |                 if i % 10 == 0:
79 |                     bar.set_postfix(loss='{:.2f}'.format(running_loss / num_preds),
80 |                                     early_stop_count=early_stop_count)
81 |
82 |             # running_loss sums per-image batch averages, so divide by the
83 |             # number of batches (not images) to get the average per-image loss
84 |             epoch_loss = running_loss / num_preds
85 |             # deep copy the model
86 |             if phase == 'val':
87 |                 if epoch_loss < best_loss:
88 |                     best_loss = epoch_loss
89 |                     best_model_wts = copy.deepcopy(baseline_net.state_dict())
90 |                     early_stop_count = 0
91 |                 else:
92 |                     early_stop_count += 1
93 |
94 |         if early_stop_count >= early_stop_patience:
95 |             break
96 |
97 |     baseline_net.load_state_dict(best_model_wts)
98 |     baseline_net.eval()
99 |
100 |     # Save model weights
101 |     Path(model_path).parent.mkdir(parents=True, exist_ok=True)
102 |     torch.save(baseline_net.state_dict(), model_path)
103 |
104 |     return baseline_net
105 |
--------------------------------------------------------------------------------
/src/cvae.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Conditional Variational Auto-encoder\n",
8 | "\n",
9 | "## Introduction\n",
10 | "This tutorial implements [Learning Structured Output Representation using Deep Conditional Generative Models](http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generati) paper, which introduced Conditional Variational Auto-encoders in 2015, using Pyro PPL.\n",
11 | "\n",
12 | "Supervised deep learning has been successfully applied to many recognition problems in machine learning and computer vision. Although it can approximate a complex many-to-one function very well when a large amount of training data is provided, the lack of probabilistic inference in current supervised deep learning methods makes it difficult to model complex structured output representations. In this work, Kihyuk Sohn, Honglak Lee and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables. The model is trained efficiently in the framework of stochastic gradient variational Bayes and allows fast prediction using stochastic feed-forward inference. They called the model Conditional Variational Auto-encoder (CVAE).\n",
13 | "\n",
14 | "The CVAE is a conditional directed graphical model whose input observations modulate the prior on Gaussian latent variables that generate the outputs. It is trained to maximize the conditional marginal log-likelihood. The authors formulate the variational learning objective of the CVAE in the framework of stochastic gradient variational Bayes (SGVB). In experiments, they demonstrate the effectiveness of the CVAE in comparison to the deterministic neural network counterparts in generating diverse but realistic output predictions using stochastic inference. Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using the MNIST database.\n",
15 | "\n",
16 | "## The problem\n",
17 | "Let's divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as an output to be predicted. The image below shows the case where one quadrant is the input:\n",
18 | "\n",
19 | "\n",
20 | "\n",
21 | "Our objective is to **learn a model that can perform probabilistic inference and make diverse predictions from a single input**. This is because we are not simply modeling a many-to-one function as in classification tasks; we may need to model a mapping from a single input to many possible outputs. One of the limitations of deterministic neural networks is that they generate only a single prediction. In the example above, the input shows a small part of a digit that might be a three or a five.\n",
22 | "\n",
23 | "## Preparing the data\n",
24 | "We use the MNIST dataset; the first step is to prepare it. Depending on how many quadrants we use as inputs, we build the datasets and dataloaders, masking the unused pixels with -1:"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "class CVAEMNIST(Dataset):\n",
34 | "    def __init__(self, root, train=True, transform=None, download=False):\n",
35 | "        self.original = MNIST(root, train=train, download=download)\n",
36 | "        self.transform = transform\n",
37 | "\n",
38 | "    def __len__(self):\n",
39 | "        return len(self.original)\n",
40 | "\n",
41 | "    def __getitem__(self, item):\n",
42 | "        image, digit = self.original[item]\n",
43 | "        sample = {'original': image, 'digit': digit}\n",
44 | "        if self.transform:\n",
45 | "            sample = self.transform(sample)\n",
46 | "\n",
47 | "        return sample\n",
48 | "\n",
49 | "\n",
50 | "class ToTensor:\n",
51 | "    def __call__(self, sample):\n",
52 | "        sample['original'] = functional.to_tensor(sample['original'])\n",
53 | "        sample['digit'] = torch.as_tensor(np.asarray(sample['digit']),\n",
54 | "                                          dtype=torch.int64)\n",
55 | "        return sample\n",
56 | "\n",
57 | "\n",
58 | "class MaskImages:\n",
59 | "    \"\"\"This torchvision image transformation prepares the MNIST digits to be\n",
60 | "    used in the tutorial. Depending on the number of quadrants to be used as\n",
61 | "    inputs (1, 2, or 3), the transformation masks the remaining (3, 2, 1)\n",
62 | "    quadrant(s), setting their pixels to -1. Additionally, the transformation\n",
63 | "    adds the target output to the sample dict as the complement of the input.\n",
64 | "    \"\"\"\n",
65 | "    def __init__(self, num_quadrant_inputs, mask_with=-1):\n",
66 | "        if num_quadrant_inputs <= 0 or num_quadrant_inputs >= 4:\n",
67 | "            raise ValueError('Number of quadrants as inputs must be 1, 2 or 3')\n",
68 | "        self.num = num_quadrant_inputs\n",
69 | "        self.mask_with = mask_with\n",
70 | "\n",
71 | "    def __call__(self, sample):\n",
72 | "        tensor = sample['original'].squeeze()\n",
73 | "        out = tensor.detach().clone()\n",
74 | "        h, w = tensor.shape\n",
75 | "\n",
76 | "        # removes the bottom left quadrant from the target output\n",
77 | "        out[h // 2:, :w // 2] = self.mask_with\n",
78 | "        # if num of quadrants to be used as input is 2,\n",
79 | "        # also removes the top left quadrant from the target output\n",
80 | "        if self.num == 2:\n",
81 | "            out[:, :w // 2] = self.mask_with\n",
82 | "        # if num of quadrants to be used as input is 3,\n",
83 | "        # also removes the top left and top right quadrants from the target output\n",
84 | "        if self.num == 3:\n",
85 | "            out[:h // 2, :] = self.mask_with\n",
86 | "\n",
87 | "        # now, sets the input as the complement of the output\n",
88 | "        inp = tensor.clone()\n",
89 | "        inp[out != self.mask_with] = self.mask_with\n",
90 | "\n",
91 | "        sample['input'] = inp\n",
92 | "        sample['output'] = out\n",
93 | "        return sample\n",
94 | "\n",
95 | "\n",
96 | "def get_data(num_quadrant_inputs, batch_size):\n",
97 | "    transforms = Compose([\n",
98 | "        ToTensor(),\n",
99 | "        MaskImages(num_quadrant_inputs=num_quadrant_inputs)\n",
100 | "    ])\n",
101 | "    datasets, dataloaders, dataset_sizes = {}, {}, {}\n",
102 | "    for mode in ['train', 'val']:\n",
103 | "        datasets[mode] = CVAEMNIST(\n",
104 | "            '../data',\n",
105 | "            download=True,\n",
106 | "            transform=transforms,\n",
107 | "            train=mode == 'train'\n",
108 | "        )\n",
109 | "        dataloaders[mode] = DataLoader(\n",
110 | "            datasets[mode],\n",
111 | "            batch_size=batch_size,\n",
112 | "            shuffle=mode == 'train',\n",
113 | "            num_workers=0\n",
114 | "        )\n",
115 | "        dataset_sizes[mode] = len(datasets[mode])\n",
116 | "\n",
117 | "    return datasets, dataloaders, dataset_sizes"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "## Baseline: Deterministic Neural Network\n",
125 | "Before we dive into the CVAE implementation, let's code the baseline model. It is a straightforward implementation:"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "class BaselineNet(nn.Module):\n",
135 | "    def __init__(self, hidden_1, hidden_2):\n",
136 | "        super().__init__()\n",
137 | "        self.fc1 = nn.Linear(784, hidden_1)\n",
138 | "        self.fc2 = nn.Linear(hidden_1, hidden_2)\n",
139 | "        self.fc3 = nn.Linear(hidden_2, 784)\n",
140 | "        self.relu = nn.ReLU()\n",
141 | "\n",
142 | "    def forward(self, x):\n",
143 | "        x = x.view(-1, 784)\n",
144 | "        hidden = self.relu(self.fc1(x))\n",
145 | "        hidden = self.relu(self.fc2(hidden))\n",
146 | "        y = torch.sigmoid(self.fc3(hidden))\n",
147 | "        return y"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "In the paper, the authors compare the baseline NN with the proposed CVAE using the negative conditional log-likelihood (CLL), averaged per image over the validation set. Because each output pixel is modeled as a Bernoulli variable whose probability is the sigmoid output of the network, the negative CLL equals the binary cross-entropy (BCE) between the prediction and the target, which PyTorch computes with `F.binary_cross_entropy`. The code below makes a small adjustment: it computes the loss only over the pixels that are not masked with -1:"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "class MaskedBCELoss(nn.Module):\n",
164 | "    def __init__(self, masked_with=-1):\n",
165 | "        super().__init__()\n",
166 | "        self.masked_with = masked_with\n",
167 | "\n",
168 | "    def forward(self, input, target):\n",
169 | "        target = target.view(input.shape)\n",
170 | "        loss = F.binary_cross_entropy(input, target, reduction='none')\n",
171 | "        loss[target == self.masked_with] = 0\n",
172 | "        return loss.sum()"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {},
178 | "source": [
179 | "The training is very straightforward. We use 500 neurons in each hidden layer, Adam optimizer with `1e-3` learning rate, and early stopping. Please check the [Github repo](https://github.com/pyro-ppl/pyro/blob/dev/examples/cvae) for the full implementation.\n",
180 | "\n",
181 | "## Deep Conditional Generative Models for Structured Output Prediction\n",
182 | "As illustrated in the image below, there are three types of variables in a deep conditional generative model (CGM): input variables $\\bf x$, output variables $\\bf y$, and latent variables $\\bf z$. The conditional generative process of the model is given in (b) as follows: for a given observation $\\bf x$, $\\bf z$ is drawn from the prior distribution $p_{\\theta}({\\bf z} | {\\bf x})$, and the output $\\bf y$ is generated from the distribution $p_{\\theta}({\\bf y} | {\\bf x, z})$. Compared to the baseline NN (a), the latent variables $\\bf z$ allow for modeling multiple modes in the conditional distribution of output variables $\\bf y$ given input $\\bf x$, making the proposed CGM suitable for modeling one-to-many mappings.\n",
183 | "\n",
184 | "\n",
185 | "\n",
186 | "\n",
187 | "Deep CGMs are trained to maximize the conditional marginal log-likelihood. Often the objective function is intractable, and we apply the SGVB framework to train the model. The empirical lower bound is written as:\n",
188 | "\n",
189 | "$$ \\tilde{\\mathcal{L}}_{\\text{CVAE}}(x, y; \\theta, \\phi) = -KL(q_{\\phi}(z | x, y) || p_{\\theta}(z | x)) + \\frac{1}{L}\\sum_{l=1}^{L}\\log p_{\\theta}(y | x, z^{(l)}) $$\n",
190 | "\n",
191 | "where $\\bf z^{(l)}$ is the $l$-th sample of the Gaussian latent variable drawn from the recognition network $q_{\\phi}({\\bf z} | {\\bf x, y})$, and $L$ is the number of samples (or particles in Pyro nomenclature).\n",
192 | "The authors call this model the conditional variational auto-encoder (CVAE). The CVAE is composed of multiple MLPs: a **recognition network** $q_{\\phi}({\\bf z} | \\bf{x, y})$, a **(conditional) prior network** $p_{\\theta}(\\bf{z} | \\bf{x})$, and a **generation network** $p_{\\theta}(\\bf{y} | \\bf{x, z})$. In designing the network architecture, we build the network components of the CVAE **on top of the baseline NN**. Specifically, as shown in (d) above, not only the direct input $\\bf x$ but also the initial guess $\\hat{y}$ made by the NN is fed into the prior network.\n",
193 | "\n",
194 | "Pyro makes it really easy to translate this architecture into code. The recognition network and the (conditional) prior network are encoders from the traditional VAE setting, while the generation network is the decoder:"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "class Encoder(nn.Module):\n",
204 | "    def __init__(self, z_dim, hidden_1, hidden_2):\n",
205 | "        super().__init__()\n",
206 | "        self.fc1 = nn.Linear(784, hidden_1)\n",
207 | "        self.fc2 = nn.Linear(hidden_1, hidden_2)\n",
208 | "        self.fc31 = nn.Linear(hidden_2, z_dim)\n",
209 | "        self.fc32 = nn.Linear(hidden_2, z_dim)\n",
210 | "        self.relu = nn.ReLU()\n",
211 | "\n",
212 | "    def forward(self, x, y):\n",
213 | "        # put x and y together in the same image for simplification\n",
214 | "        xc = x.clone()\n",
215 | "        xc[x == -1] = y[x == -1]\n",
216 | "        xc = xc.view(-1, 784)\n",
217 | "        # then compute the hidden units\n",
218 | "        hidden = self.relu(self.fc1(xc))\n",
219 | "        hidden = self.relu(self.fc2(hidden))\n",
220 | "        # then return a mean vector and a (positive) square root covariance\n",
221 | "        # each of size batch_size x z_dim\n",
222 | "        z_loc = self.fc31(hidden)\n",
223 | "        z_scale = torch.exp(self.fc32(hidden))\n",
224 | "        return z_loc, z_scale\n",
225 | "\n",
226 | "\n",
227 | "class Decoder(nn.Module):\n",
228 | "    def __init__(self, z_dim, hidden_1, hidden_2):\n",
229 | "        super().__init__()\n",
230 | "        self.fc1 = nn.Linear(z_dim, hidden_1)\n",
231 | "        self.fc2 = nn.Linear(hidden_1, hidden_2)\n",
232 | "        self.fc3 = nn.Linear(hidden_2, 784)\n",
233 | "        self.relu = nn.ReLU()\n",
234 | "\n",
235 | "    def forward(self, z):\n",
236 | "        y = self.relu(self.fc1(z))\n",
237 | "        y = self.relu(self.fc2(y))\n",
238 | "        y = torch.sigmoid(self.fc3(y))\n",
239 | "        return y\n",
240 | "\n",
241 | "\n",
242 | "class CVAE(nn.Module):\n",
243 | "    def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net):\n",
244 | "        super().__init__()\n",
245 | "        # The CVAE is composed of multiple MLPs, such as recognition network\n",
246 | "        # qφ(z|x, y), (conditional) prior network pθ(z|x), and generation\n",
247 | "        # network pθ(y|x, z). Also, CVAE is built on top of the NN: not only\n",
248 | "        # the direct input x, but also the initial guess y_hat made by the NN\n",
249 | "        # are fed into the prior network.\n",
250 | "        self.baseline_net = pre_trained_baseline_net\n",
251 | "        self.prior_net = Encoder(z_dim, hidden_1, hidden_2)\n",
252 | "        self.generation_net = Decoder(z_dim, hidden_1, hidden_2)\n",
253 | "        self.recognition_net = Encoder(z_dim, hidden_1, hidden_2)\n",
254 | "\n",
255 | "    def model(self, xs, ys=None):\n",
256 | "        # register this pytorch module and all of its sub-modules with pyro\n",
257 | "        pyro.module(\"generation_net\", self)\n",
258 | "        batch_size = xs.shape[0]\n",
259 | "        with pyro.plate(\"data\"):\n",
260 | "\n",
261 | "            # Prior network uses the baseline predictions as initial guess.\n",
262 | "            # This is the generative process with recurrent connection\n",
263 | "            with torch.no_grad():\n",
264 | "                # this ensures the training process does not change the\n",
265 | "                # baseline network\n",
266 | "                y_hat = self.baseline_net(xs).view(xs.shape)\n",
267 | "\n",
268 | "            # sample the handwriting style from the prior distribution, which is\n",
269 | "            # modulated by the input xs.\n",
270 | "            prior_loc, prior_scale = self.prior_net(xs, y_hat)\n",
271 | "            zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))\n",
272 | "\n",
273 | "            # the output y is generated from the distribution pθ(y|x, z)\n",
274 | "            loc = self.generation_net(zs)\n",
275 | "\n",
276 | "            if ys is not None:\n",
277 | "                # In training, we will only sample in the masked image\n",
278 | "                mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1)\n",
279 | "                mask_ys = ys[xs == -1].view(batch_size, -1)\n",
280 | "                pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)\n",
281 | "            else:\n",
282 | "                # In testing, no need to sample: the output is already a\n",
283 | "                # probability in [0, 1] range, which better represent pixel\n",
284 | "                # values considering grayscale. If we sample, we will force\n",
285 | "                # each pixel to be either 0 or 1, killing the grayscale\n",
286 | "                pyro.deterministic('y', loc.detach())\n",
287 | "\n",
288 | "        # return the loc so we can visualize it later\n",
289 | "        return loc\n",
290 | "\n",
291 | "    def guide(self, xs, ys=None):\n",
292 | "        with pyro.plate(\"data\"):\n",
293 | "            if ys is None:\n",
294 | "                # at inference time, ys is not provided. In that case,\n",
295 | "                # the model uses the prior network\n",
296 | "                y_hat = self.baseline_net(xs).view(xs.shape)\n",
297 | "                loc, scale = self.prior_net(xs, y_hat)\n",
298 | "            else:\n",
299 | "                # at training time, uses the variational distribution\n",
300 | "                # q(z|x,y) = normal(loc(x,y),scale(x,y))\n",
301 | "                loc, scale = self.recognition_net(xs, ys)\n",
302 | "\n",
303 | "            pyro.sample(\"z\", dist.Normal(loc, scale).to_event(1))\n",
304 | "\n",
305 | "    def save(self, model_path):\n",
306 | "        torch.save({\n",
307 | "            'prior': self.prior_net.state_dict(),\n",
308 | "            'generation': self.generation_net.state_dict(),\n",
309 | "            'recognition': self.recognition_net.state_dict()\n",
310 | "        }, model_path)\n",
311 | "\n",
312 | "    def load(self, model_path, map_location=None):\n",
313 | "        net_weights = torch.load(model_path, map_location=map_location)\n",
314 | "        self.prior_net.load_state_dict(net_weights['prior'])\n",
315 | "        self.generation_net.load_state_dict(net_weights['generation'])\n",
316 | "        self.recognition_net.load_state_dict(net_weights['recognition'])\n",
317 | "        self.prior_net.eval()\n",
318 | "        self.generation_net.eval()\n",
319 | "        self.recognition_net.eval()"
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "metadata": {},
325 | "source": [
326 | "## Evaluating the results\n",
327 | "For qualitative analysis, we visualize the generated output samples in the next figure. The baseline NNs can only make a single deterministic prediction; as a result, their output is blurry and looks unrealistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they even change their identity (digit label), for example from 3 to 5 or from 4 to 9, and vice versa.\n",
328 | "\n",
329 | "\n",
330 | "\n",
331 | "We also provide quantitative evidence in the next table, which reports the estimated negative marginal conditional log-likelihoods (CLLs) on the validation set; lower is better.\n",
332 | "\n",
333 | "| | 1 quadrant | 2 quadrants | 3 quadrants |\n",
334 | "|--------------------|------------|-------------|-------------|\n",
335 | "| NN (baseline) | 100.4 | 61.9 | 25.4 |\n",
336 | "| CVAE (Monte Carlo) | 71.8 | 51.0 | 24.2 |\n",
337 | "| Performance gap | 28.6 | 10.9 | 1.2 |\n",
338 | "\n",
339 | "Our results are similar to those reported by the authors in the paper. We trained for only 50 epochs with an early stopping patience of 3 epochs; training for longer could improve the results. Nevertheless, we observe the same effect shown in the paper: **the CVAE achieves lower estimated negative CLLs than the baseline NN in every setting**, with the largest gap when only one quadrant is given as input.\n",
340 | "\n",
341 | "See the full code on [Github](https://github.com/pyro-ppl/pyro/blob/dev/examples/cvae).\n",
342 | "\n",
343 | "## References\n",
344 | "\n",
345 | "[1] `Learning Structured Output Representation using Deep Conditional Generative Models`,\n",
346 | "Kihyuk Sohn, Xinchen Yan, Honglak Lee"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {},
353 | "outputs": [],
354 | "source": []
355 | }
356 | ],
357 | "metadata": {
358 | "kernelspec": {
359 | "display_name": "Python 3",
360 | "language": "python",
361 | "name": "python3"
362 | },
363 | "language_info": {
364 | "codemirror_mode": {
365 | "name": "ipython",
366 | "version": 3
367 | },
368 | "file_extension": ".py",
369 | "mimetype": "text/x-python",
370 | "name": "python",
371 | "nbconvert_exporter": "python",
372 | "pygments_lexer": "ipython3",
373 | "version": "3.7.6"
374 | }
375 | },
376 | "nbformat": 4,
377 | "nbformat_minor": 4
378 | }
379 |
--------------------------------------------------------------------------------
/src/cvae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from pathlib import Path
4 | import pyro
5 | import pyro.distributions as dist
6 | from pyro.infer import SVI, Trace_ELBO, Predictive
7 | import torch
8 | import torch.nn as nn
9 | from tqdm import tqdm
10 | from mnist import get_val_images
11 |
12 |
13 | class Encoder(nn.Module):
14 | def __init__(self, z_dim, hidden_1, hidden_2):
15 | super().__init__()
16 | self.fc1 = nn.Linear(784, hidden_1)
17 | self.fc2 = nn.Linear(hidden_1, hidden_2)
18 | self.fc31 = nn.Linear(hidden_2, z_dim)
19 | self.fc32 = nn.Linear(hidden_2, z_dim)
20 | self.relu = nn.ReLU()
21 |
22 | def forward(self, x, y):
23 | # put x and y together in the same image for simplification
24 | xc = x.clone()
25 | xc[x == -1] = y[x == -1]
26 | xc = xc.view(-1, 784)
27 | # then compute the hidden units
28 | hidden = self.relu(self.fc1(xc))
29 | hidden = self.relu(self.fc2(hidden))
30 | # then return a mean vector and a (positive) square root covariance
31 | # each of size batch_size x z_dim
32 | z_loc = self.fc31(hidden)
33 | z_scale = torch.exp(self.fc32(hidden))
34 | return z_loc, z_scale
35 |
36 |
37 | class Decoder(nn.Module):
38 | def __init__(self, z_dim, hidden_1, hidden_2):
39 | super().__init__()
40 | self.fc1 = nn.Linear(z_dim, hidden_1)
41 | self.fc2 = nn.Linear(hidden_1, hidden_2)
42 | self.fc3 = nn.Linear(hidden_2, 784)
43 | self.relu = nn.ReLU()
44 |
45 | def forward(self, z):
46 | y = self.relu(self.fc1(z))
47 | y = self.relu(self.fc2(y))
48 | y = torch.sigmoid(self.fc3(y))
49 | return y
50 |
51 |
52 | class CVAE(nn.Module):
53 | def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net):
54 | super().__init__()
55 | # The CVAE is composed of multiple MLPs, such as recognition network
56 | # qφ(z|x, y), (conditional) prior network pθ(z|x), and generation
57 | # network pθ(y|x, z). Also, CVAE is built on top of the NN: not only
58 | # the direct input x, but also the initial guess y_hat made by the NN
59 | # are fed into the prior network.
60 | self.baseline_net = pre_trained_baseline_net
61 | self.prior_net = Encoder(z_dim, hidden_1, hidden_2)
62 | self.generation_net = Decoder(z_dim, hidden_1, hidden_2)
63 | self.recognition_net = Encoder(z_dim, hidden_1, hidden_2)
64 |
65 | def model(self, xs, ys=None):
66 | # register this pytorch module and all of its sub-modules with pyro
67 | pyro.module("generation_net", self)
68 | batch_size = xs.shape[0]
69 | with pyro.plate("data"):
70 |
71 | # Prior network uses the baseline predictions as initial guess.
72 | # This is the generative process with recurrent connection
73 | with torch.no_grad():
74 | # this ensures the training process does not change the
75 | # baseline network
76 | y_hat = self.baseline_net(xs).view(xs.shape)
77 |
78 | # sample the handwriting style from the prior distribution, which is
79 | # modulated by the input xs.
80 | prior_loc, prior_scale = self.prior_net(xs, y_hat)
81 | zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))
82 |
83 | # the output y is generated from the distribution pθ(y|x, z)
84 | loc = self.generation_net(zs)
85 |
86 | if ys is not None:
87 | # In training, we will only sample in the masked image
88 | mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1)
89 | mask_ys = ys[xs == -1].view(batch_size, -1)
90 | pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)
91 | else:
92 | # In testing, no need to sample: the output is already a
93 | # probability in [0, 1] range, which better represent pixel
94 | # values considering grayscale. If we sample, we will force
95 | # each pixel to be either 0 or 1, killing the grayscale
96 | pyro.deterministic('y', loc.detach())
97 |
98 | # return the loc so we can visualize it later
99 | return loc
100 |
101 | def guide(self, xs, ys=None):
102 | with pyro.plate("data"):
103 | if ys is None:
104 | # at inference time, ys is not provided. In that case,
105 | # the guide falls back to the prior network
106 | y_hat = self.baseline_net(xs).view(xs.shape)
107 | loc, scale = self.prior_net(xs, y_hat)
108 | else:
109 | # at training time, uses the variational distribution
110 | # q(z|x,y) = normal(loc(x,y),scale(x,y))
111 | loc, scale = self.recognition_net(xs, ys)
112 |
113 | pyro.sample("z", dist.Normal(loc, scale).to_event(1))
114 |
115 | def save(self, model_path):
116 | torch.save({
117 | 'prior': self.prior_net.state_dict(),
118 | 'generation': self.generation_net.state_dict(),
119 | 'recognition': self.recognition_net.state_dict()
120 | }, model_path)
121 |
122 | def load(self, model_path, map_location=None):
123 | net_weights = torch.load(model_path, map_location=map_location)
124 | self.prior_net.load_state_dict(net_weights['prior'])
125 | self.generation_net.load_state_dict(net_weights['generation'])
126 | self.recognition_net.load_state_dict(net_weights['recognition'])
127 | self.prior_net.eval()
128 | self.generation_net.eval()
129 | self.recognition_net.eval()
130 |
131 |
132 | def train(device, dataloaders, dataset_sizes, learning_rate, num_epochs,
133 | early_stop_patience, model_path, pre_trained_baseline_net):
134 |
135 | # clear param store
136 | pyro.clear_param_store()
137 |
138 | cvae_net = CVAE(200, 500, 500, pre_trained_baseline_net)
139 | cvae_net.to(device)
140 | optimizer = pyro.optim.Adam({"lr": learning_rate})
141 | svi = SVI(cvae_net.model, cvae_net.guide, optimizer, loss=Trace_ELBO())
142 |
143 | best_loss = np.inf
144 | early_stop_count = 0
145 | Path(model_path).parent.mkdir(parents=True, exist_ok=True)
146 |
147 | # to track evolution
148 | val_inp, digits = get_val_images(num_quadrant_inputs=1,
149 | num_images=30, shuffle=False)
150 | val_inp = val_inp.to(device)
151 | samples = []
152 | losses = []
153 |
154 | for epoch in range(num_epochs):
155 | # Each epoch has a training and validation phase
156 | for phase in ['train', 'val']:
157 | running_loss = 0.0
158 |
159 | # Iterate over data.
160 | bar = tqdm(dataloaders[phase],
161 | desc='CVAE Epoch {} {}'.format(epoch, phase).ljust(20))
162 | for i, batch in enumerate(bar):
163 | inputs = batch['input'].to(device)
164 | outputs = batch['output'].to(device)
165 |
166 | if phase == 'train':
167 | loss = svi.step(inputs, outputs) / inputs.size(0)
168 | else:
169 | loss = svi.evaluate_loss(inputs, outputs) / inputs.size(0)
170 |
171 | # statistics
172 | running_loss += loss * inputs.size(0)
173 | if i % 10 == 0:
174 | bar.set_postfix(loss='{:.2f}'.format(loss),
175 | early_stop_count=early_stop_count)
176 |
177 | # track evolution
178 | if phase == 'train':
179 | df = pd.DataFrame(columns=['epoch', 'loss'])
180 | df.loc[0] = [epoch + float(i) / len(dataloaders[phase]), loss]
181 | losses.append(df)
182 | if i % 47 == 0: # every 10% of training (469)
183 | dfs = predict_samples(
184 | val_inp, digits, cvae_net,
185 | epoch + float(i) / len(dataloaders[phase]),
186 | )
187 | samples.append(dfs)
188 |
189 | epoch_loss = running_loss / dataset_sizes[phase]
190 | # save the model whenever the validation loss improves
191 | if phase == 'val':
192 | if epoch_loss < best_loss:
193 | best_loss = epoch_loss
194 | cvae_net.save(model_path)
195 | early_stop_count = 0
196 | else:
197 | early_stop_count += 1
198 |
199 | if early_stop_count >= early_stop_patience:
200 | break
201 |
202 | # Load the best model weights saved during training
203 | cvae_net.load(model_path)
204 |
205 | # record evolution
206 | samples = pd.concat(samples, axis=0, ignore_index=True)
207 | samples.to_csv('samples.csv', index=False)
208 |
209 | losses = pd.concat(losses, axis=0, ignore_index=True)
210 | losses.to_csv('losses.csv', index=False)
211 |
212 | return cvae_net
213 |
214 |
215 | def predict_samples(inputs, digits, pre_trained_cvae, epoch_frac):
216 | predictive = Predictive(pre_trained_cvae.model,
217 | guide=pre_trained_cvae.guide,
218 | num_samples=1)
219 | preds = predictive(inputs)
220 | y_loc = preds['y'].squeeze().detach().cpu().numpy()
221 | dfs = pd.DataFrame(data=y_loc)
222 | dfs['digit'] = digits.numpy()
223 | dfs['epoch'] = epoch_frac
224 | return dfs
225 |
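In `CVAE.model`, only the pixels hidden in the input (`xs == -1`) contribute to the likelihood of `y`. The NumPy sketch below mirrors that selection to make the shapes explicit. `masked_bernoulli_log_prob` is a hypothetical helper, not part of the repository; it computes the Bernoulli log-probability by hand instead of calling `dist.Bernoulli`, and it assumes flattened `(batch, 784)` arrays.

```python
import numpy as np


def masked_bernoulli_log_prob(xs, ys, probs, mask_value=-1.0):
    """Per-image Bernoulli log-likelihood over the pixels hidden in the input.

    xs:    (batch, 784) masked inputs; hidden pixels equal mask_value
    ys:    (batch, 784) targets in [0, 1]; values at visible pixels are ignored
    probs: (batch, 784) decoder outputs in (0, 1)
    """
    hidden = xs == mask_value          # pixels the model has to predict
    batch_size = xs.shape[0]
    # every image in a batch hides the same quadrants, so the selection
    # reshapes cleanly into (batch, num_hidden_pixels)
    p = probs[hidden].reshape(batch_size, -1)
    y = ys[hidden].reshape(batch_size, -1)
    eps = 1e-7                         # guard against log(0)
    p = np.clip(p, eps, 1 - eps)
    # sum over the hidden pixels, analogous to .to_event(1) in the model
    return (y * np.log(p) + (1 - y) * np.log(1 - p)).sum(axis=1)
```

Boolean indexing followed by a reshape is the same trick the model uses with `loc[(xs == -1).view(-1, 784)].view(batch_size, -1)`.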
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | import pyro
4 | import torch
5 | import baseline
6 | import cvae
7 | from util import get_data, visualize, generate_table
8 |
9 |
10 | def main(args):
11 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda
12 | else "cpu")
13 | results = []
14 | columns = []
15 |
16 | for num_quadrant_inputs in args.num_quadrant_inputs:
17 | # adds an s in case of plural quadrants
18 | maybes = 's' if num_quadrant_inputs > 1 else ''
19 |
20 | print('Training with {} quadrant{} as input...'
21 | .format(num_quadrant_inputs, maybes))
22 |
23 | # Dataset
24 | datasets, dataloaders, dataset_sizes = get_data(
25 | num_quadrant_inputs=num_quadrant_inputs,
26 | batch_size=128
27 | )
28 |
29 | # Train baseline
30 | baseline_net = baseline.train(
31 | device=device,
32 | dataloaders=dataloaders,
33 | dataset_sizes=dataset_sizes,
34 | learning_rate=args.learning_rate,
35 | num_epochs=args.num_epochs,
36 | early_stop_patience=args.early_stop_patience,
37 | model_path='baseline_net_q{}.pth'.format(num_quadrant_inputs)
38 | )
39 |
40 | # Train CVAE
41 | cvae_net = cvae.train(
42 | device=device,
43 | dataloaders=dataloaders,
44 | dataset_sizes=dataset_sizes,
45 | learning_rate=args.learning_rate,
46 | num_epochs=args.num_epochs,
47 | early_stop_patience=args.early_stop_patience,
48 | model_path='cvae_net_q{}.pth'.format(num_quadrant_inputs),
49 | pre_trained_baseline_net=baseline_net
50 | )
51 |
52 | # Visualize conditional predictions
53 | visualize(
54 | device=device,
55 | num_quadrant_inputs=num_quadrant_inputs,
56 | pre_trained_baseline=baseline_net,
57 | pre_trained_cvae=cvae_net,
58 | num_images=args.num_images,
59 | num_samples=args.num_samples,
60 | image_path='cvae_plot_q{}.png'.format(num_quadrant_inputs)
61 | )
62 |
63 | # Retrieve conditional log likelihood
64 | df = generate_table(
65 | device=device,
66 | num_quadrant_inputs=num_quadrant_inputs,
67 | pre_trained_baseline=baseline_net,
68 | pre_trained_cvae=cvae_net,
69 | num_particles=args.num_particles,
70 | col_name='{} quadrant{}'.format(num_quadrant_inputs, maybes)
71 | )
72 | results.append(df)
73 | columns.append('{} quadrant{}'.format(num_quadrant_inputs, maybes))
74 |
75 | results = pd.concat(results, axis=1, ignore_index=True)
76 | results.columns = columns
77 | results.loc['Performance gap', :] = results.iloc[0, :] - results.iloc[1, :]
78 | results.to_csv('results.csv')
79 |
80 |
81 | if __name__ == '__main__':
82 | assert pyro.__version__.startswith('1.4.0')
83 | # parse command line arguments
84 | parser = argparse.ArgumentParser(description="parse args")
85 | parser.add_argument('-nq', '--num-quadrant-inputs', metavar='N', type=int,
86 | nargs='+', default=[1, 2, 3],
87 | help='number of quadrants to use as inputs')
88 | parser.add_argument('-n', '--num-epochs', default=101, type=int,
89 | help='number of training epochs')
90 | parser.add_argument('-esp', '--early-stop-patience', default=3, type=int,
91 | help='early stop patience')
92 | parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float,
93 | help='learning rate')
94 | parser.add_argument('--cuda', action='store_true', default=False,
95 | help='whether to use cuda')
96 | parser.add_argument('-vi', '--num-images', default=10, type=int,
97 | help='number of images to visualize')
98 | parser.add_argument('-vs', '--num-samples', default=10, type=int,
99 | help='number of samples to visualize per image')
100 | parser.add_argument('-p', '--num-particles', default=10, type=int,
101 | help='number of particles to estimate log pθ(y|x,z) in the ELBO')
102 | args = parser.parse_args()
103 |
104 | main(args)
105 |
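`--early-stop-patience` (default 3) is passed to both trainers. In `cvae.train`, the counter resets to 0 whenever the validation loss improves (and the weights are saved), and training stops once the counter reaches the patience. A minimal pure-Python sketch of that rule, using a hypothetical helper `early_stop_epochs` that works on a precomputed list of validation losses rather than a live training loop:

```python
def early_stop_epochs(val_losses, patience):
    """Return (epochs_run, best_epoch) under patience-based early stopping.

    Mirrors the counter logic in cvae.train: reset on improvement,
    increment otherwise, stop once the counter reaches `patience`.
    """
    best_loss = float('inf')
    best_epoch = None
    count = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # cvae.train saves the weights at this point
            best_loss, best_epoch, count = loss, epoch, 0
        else:
            count += 1
        if count >= patience:
            return epoch + 1, best_epoch
    return len(val_losses), best_epoch


print(early_stop_epochs([5.0, 4.0, 4.5, 4.2, 4.1, 3.9], patience=3))  # (5, 1)
```

Because `cvae.train` reloads the saved checkpoint after the loop, the returned model corresponds to `best_epoch`, not to the last epoch run.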
--------------------------------------------------------------------------------
/src/mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | from torchvision.datasets import MNIST
5 | from torchvision.transforms import Compose, functional
6 |
7 |
8 | class CVAEMNIST(Dataset):
9 | def __init__(self, root, train=True, transform=None, download=False):
10 | self.original = MNIST(root, train=train, download=download)
11 | self.transform = transform
12 |
13 | def __len__(self):
14 | return len(self.original)
15 |
16 | def __getitem__(self, item):
17 | image, digit = self.original[item]
18 | sample = {'original': image, 'digit': digit}
19 | if self.transform:
20 | sample = self.transform(sample)
21 |
22 | return sample
23 |
24 |
25 | class ToTensor:
26 | def __call__(self, sample):
27 | sample['original'] = functional.to_tensor(sample['original'])
28 | sample['digit'] = torch.as_tensor(np.asarray(sample['digit']),
29 | dtype=torch.int64)
30 | return sample
31 |
32 |
33 | class MaskImages:
34 | """This torchvision image transformation prepares the MNIST digits to be
35 | used in the tutorial. Depending on the number of quadrants to be used as
36 | inputs (1, 2, or 3), the transformation masks the remaining (3, 2, or 1)
37 | quadrant(s) by setting their pixels to `mask_with` (-1 by default). It also
38 | adds the target output to the sample dict as the complement of the input.
39 | """
40 | def __init__(self, num_quadrant_inputs, mask_with=-1):
41 | if num_quadrant_inputs <= 0 or num_quadrant_inputs >= 4:
42 | raise ValueError('Number of quadrants as inputs must be 1, 2 or 3')
43 | self.num = num_quadrant_inputs
44 | self.mask_with = mask_with
45 |
46 | def __call__(self, sample):
47 | tensor = sample['original'].squeeze()
48 | out = tensor.detach().clone()
49 | h, w = tensor.shape
50 |
51 | # removes the bottom left quadrant from the target output
52 | out[h // 2:, :w // 2] = self.mask_with
53 | # if num of quadrants to be used as input is 2,
54 | # also removes the top left quadrant from the target output
55 | if self.num == 2:
56 | out[:, :w // 2] = self.mask_with
57 | # if num of quadrants to be used as input is 3,
58 | # also removes the top left and top right quadrants from the target output
59 | if self.num == 3:
60 | out[:h // 2, :] = self.mask_with
61 |
62 | # now, sets the input as the complement of the output
63 | inp = tensor.clone()
64 | inp[out != self.mask_with] = self.mask_with
65 |
66 | sample['input'] = inp
67 | sample['output'] = out
68 | return sample
69 |
70 |
71 | def get_data(num_quadrant_inputs, batch_size):
72 | transforms = Compose([
73 | ToTensor(),
74 | MaskImages(num_quadrant_inputs=num_quadrant_inputs)
75 | ])
76 | datasets, dataloaders, dataset_sizes = {}, {}, {}
77 | for mode in ['train', 'val']:
78 | datasets[mode] = CVAEMNIST(
79 | '../data',
80 | download=True,
81 | transform=transforms,
82 | train=mode == 'train'
83 | )
84 | dataloaders[mode] = DataLoader(
85 | datasets[mode],
86 | batch_size=batch_size,
87 | shuffle=mode == 'train',
88 | num_workers=0
89 | )
90 | dataset_sizes[mode] = len(datasets[mode])
91 |
92 | return datasets, dataloaders, dataset_sizes
93 |
94 |
95 | def get_val_images(num_quadrant_inputs, num_images, shuffle):
96 | datasets, _, dataset_sizes = get_data(
97 | num_quadrant_inputs=num_quadrant_inputs,
98 | batch_size=num_images
99 | )
100 | dataloader = DataLoader(datasets['val'], batch_size=num_images,
101 | shuffle=shuffle)
102 |
103 | batch = next(iter(dataloader))
104 | inputs = batch['input']
105 | digits = batch['digit']
106 | return inputs, digits
107 |
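`MaskImages` always gives the bottom-left quadrant to the input; with 2 inputs it gives the whole left half, and with 3 inputs everything except the bottom-right quadrant. The NumPy sketch below reproduces the transform on a small array so the quadrant layout is easy to inspect. `mask_quadrants` is a hypothetical stand-in for the repository's class and assumes pixel values never equal `mask_with`.

```python
import numpy as np


def mask_quadrants(image, num_quadrant_inputs, mask_with=-1.0):
    """NumPy version of MaskImages: returns (input, output) arrays."""
    if num_quadrant_inputs not in (1, 2, 3):
        raise ValueError('Number of quadrants as inputs must be 1, 2 or 3')
    h, w = image.shape
    out = image.astype(float).copy()
    out[h // 2:, :w // 2] = mask_with    # bottom-left always goes to the input
    if num_quadrant_inputs == 2:
        out[:, :w // 2] = mask_with      # whole left half goes to the input
    if num_quadrant_inputs == 3:
        out[:h // 2, :] = mask_with      # top half goes to the input as well
    inp = image.astype(float).copy()
    inp[out != mask_with] = mask_with    # input is the complement of the output
    return inp, out


inp, out = mask_quadrants(np.arange(16, dtype=float).reshape(4, 4) / 16, 1)
```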
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | from pathlib import Path
5 | from pyro.infer import Predictive, Trace_ELBO
6 | from sklearn.manifold import TSNE
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from torchvision.utils import make_grid
10 | from tqdm import tqdm
11 | from baseline import MaskedBCELoss, BaselineNet
12 | from mnist import get_data, get_val_images
13 | from cvae import CVAE
14 |
15 |
16 | def imshow(inp, image_path=None):
17 | inp = inp.cpu().numpy().transpose((1, 2, 0))
18 | space = np.ones((inp.shape[0], 50, inp.shape[2]))
19 | inp = np.concatenate([space, inp], axis=1)
20 |
21 | ax = plt.axes(frameon=False, xticks=[], yticks=[])
22 | ax.text(0, 23, 'Inputs:')
23 | ax.text(0, 23 + 28 + 3, 'Truth:')
24 | ax.text(0, 23 + (28 + 3) * 2, 'NN:')
25 | ax.text(0, 23 + (28 + 3) * 3, 'CVAE:')
26 | ax.imshow(inp)
27 |
28 | if image_path is not None:
29 | Path(image_path).parent.mkdir(parents=True, exist_ok=True)
30 | plt.savefig(image_path, bbox_inches='tight', pad_inches=0.1)
31 | else:
32 | plt.show()
33 |
34 | plt.clf()
35 |
36 |
37 | def visualize(device, num_quadrant_inputs, pre_trained_baseline,
38 | pre_trained_cvae, num_images, num_samples, image_path=None):
39 |
40 | # Load sample random data
41 | datasets, _, dataset_sizes = get_data(
42 | num_quadrant_inputs=num_quadrant_inputs,
43 | batch_size=num_images
44 | )
45 | dataloader = DataLoader(datasets['val'], batch_size=num_images, shuffle=True)
46 |
47 | batch = next(iter(dataloader))
48 | inputs = batch['input'].to(device)
49 | outputs = batch['output'].to(device)
50 | originals = batch['original'].to(device)
51 |
52 | # Make predictions
53 | with torch.no_grad():
54 | baseline_preds = pre_trained_baseline(inputs).view(outputs.shape)
55 |
56 | predictive = Predictive(pre_trained_cvae.model,
57 | guide=pre_trained_cvae.guide,
58 | num_samples=num_samples)
59 | cvae_preds = predictive(inputs)['y'].view(num_samples, num_images, 28, 28)
60 |
61 | # Predictions are only meaningful for the missing (masked) pixels. For
62 | # visualization, copy the known input quadrant(s) into the predictions
63 | # so each image shows a complete digit
64 | baseline_preds[outputs == -1] = inputs[outputs == -1]
65 | for i in range(cvae_preds.shape[0]):
66 | cvae_preds[i][outputs == -1] = inputs[outputs == -1]
67 |
68 | # adjust tensor sizes
69 | inputs = inputs.unsqueeze(1)
70 | inputs[inputs == -1] = 1
71 | baseline_preds = baseline_preds.unsqueeze(1)
72 | cvae_preds = cvae_preds.view(-1, 28, 28).unsqueeze(1)
73 |
74 | # make grids
75 | inputs_tensor = make_grid(inputs, nrow=num_images, padding=0)
76 | originals_tensor = make_grid(originals, nrow=num_images, padding=0)
77 | separator_tensor = torch.ones((3, 5, originals_tensor.shape[-1])).to(device)
78 | baseline_tensor = make_grid(baseline_preds, nrow=num_images, padding=0)
79 | cvae_tensor = make_grid(cvae_preds, nrow=num_images, padding=0)
80 |
81 | # add vertical and horizontal lines
82 | for tensor in [originals_tensor, baseline_tensor, cvae_tensor]:
83 | for i in range(num_images - 1):
84 | tensor[:, :, (i + 1) * 28] = 0.3
85 |
86 | for i in range(num_samples - 1):
87 | cvae_tensor[:, (i + 1) * 28, :] = 0.3
88 |
89 | # concatenate all tensors
90 | grid_tensor = torch.cat([inputs_tensor, separator_tensor, originals_tensor,
91 | separator_tensor, baseline_tensor,
92 | separator_tensor, cvae_tensor], dim=1)
93 | # plot tensors
94 | imshow(grid_tensor, image_path=image_path)
95 |
96 |
97 | def generate_table(device, num_quadrant_inputs, pre_trained_baseline,
98 | pre_trained_cvae, num_particles, col_name):
99 |
100 | # Load the validation data
101 | datasets, dataloaders, dataset_sizes = get_data(
102 | num_quadrant_inputs=num_quadrant_inputs,
103 | batch_size=32
104 | )
105 |
106 | # Loss functions: masked BCE for the NN, Monte Carlo ELBO for the CVAE
107 | criterion = MaskedBCELoss()
108 | loss_fn = Trace_ELBO(num_particles=num_particles).differentiable_loss
109 |
110 | baseline_cll = 0.0
111 | cvae_mc_cll = 0.0
112 | num_preds = 0
113 |
114 | df = pd.DataFrame(index=['NN (baseline)', 'CVAE (Monte Carlo)'],
115 | columns=[col_name])
116 |
117 | # Iterate over data.
118 | bar = tqdm(dataloaders['val'], desc='Generating predictions'.ljust(20))
119 | for batch in bar:
120 | inputs = batch['input'].to(device)
121 | outputs = batch['output'].to(device)
122 | num_preds += 1
123 |
124 | # Compute negative log likelihood for the baseline NN
125 | with torch.no_grad():
126 | preds = pre_trained_baseline(inputs)
127 | baseline_cll += criterion(preds, outputs).item() / inputs.size(0)
128 |
129 | # Estimate the CVAE's negative conditional log likelihood (negative ELBO)
130 | cvae_mc_cll += loss_fn(pre_trained_cvae.model,
131 | pre_trained_cvae.guide,
132 | inputs, outputs).detach().item() / inputs.size(0)
133 |
134 | df.iloc[0, 0] = baseline_cll / num_preds
135 | df.iloc[1, 0] = cvae_mc_cll / num_preds
136 | return df
137 |
138 |
139 | if __name__ == '__main__':
140 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
141 |
142 | # Dataset
143 | datasets, dataloaders, dataset_sizes = get_data(
144 | num_quadrant_inputs=1,
145 | batch_size=128
146 | )
147 | baseline_net = BaselineNet(500, 500).to(device)
148 | baseline_net.load_state_dict(
149 | torch.load('baseline_net_q1.pth',
150 | map_location='cpu'))
151 | baseline_net.eval()
152 |
153 | cvae_net = CVAE(200, 500, 500, baseline_net).to(device)
154 | # CVAE.save stores a dict of sub-network weights, so use CVAE.load
155 | cvae_net.load('cvae_net_q1.pth',
156 | map_location='cpu')
157 | cvae_net.eval()
158 |
159 | visualize(
160 | device=device,
161 | num_quadrant_inputs=1,
162 | pre_trained_baseline=baseline_net,
163 | pre_trained_cvae=cvae_net,
164 | num_images=10,
165 | num_samples=10
166 | )
167 |
168 | # df = generate_table(
169 | # device=device,
170 | # num_quadrant_inputs=1,
171 | # pre_trained_baseline=baseline_net,
172 | # pre_trained_cvae=cvae_net,
173 | # num_particles=10,
174 | # col_name='{} quadrant'.format(1)
175 | # )
176 | # print(df)
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
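`visualize` overwrites the predictions at the input pixels (`outputs == -1`) so each row shows a full digit: known quadrants from the input, missing quadrants from the model. With NumPy the same step can be written without the per-sample loop, because `np.where` broadcasts the `(N, 28, 28)` mask over a leading sample dimension. A sketch with a hypothetical helper `complete_with_inputs`; unlike the original in-place assignment, it returns a new array.

```python
import numpy as np


def complete_with_inputs(preds, inputs, outputs, mask_value=-1.0):
    """Return predictions with the known input pixels copied in.

    preds:   (num_samples, N, 28, 28) or (N, 28, 28) predicted pixel values
    inputs:  (N, 28, 28) masked inputs
    outputs: (N, 28, 28) targets; mask_value marks pixels that belong to the input
    """
    known = outputs == mask_value
    # np.where broadcasts the mask and inputs over any leading sample dimension
    return np.where(known, inputs, preds)
```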
--------------------------------------------------------------------------------
/src/video.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | from sys import platform
4 | if platform == 'linux':
5 | import matplotlib
6 | matplotlib.use('Agg')
7 |
8 | import matplotlib.pyplot as plt
9 | import matplotlib.animation as animation
10 | import matplotlib.gridspec as gridspec
11 | import numpy as np
12 | from mnist import get_val_images
13 | import cv2
14 | from PIL import Image
15 |
16 |
17 | fig = plt.figure(figsize=(8, 7))
18 | spec = gridspec.GridSpec(nrows=2, ncols=3, hspace=.3, wspace=.25)
19 | axl = fig.add_subplot(spec[0, :])
20 | axd = [fig.add_subplot(spec[1, i]) for i in range(3)]
21 | plt.subplots_adjust(left=0.10, bottom=0.05, right=0.90, top=0.90,
22 | wspace=0.3, hspace=0.25)
23 | im = Image.open('../data/pyro_logo.png')
24 | im.thumbnail((900, 900), Image.LANCZOS)
25 |
26 | im2 = im.copy()
27 | im2.putalpha(30)
28 | im.paste(im2, im)
29 |
30 | im = np.array(im).astype(float) / 255
31 | fig.figimage(im, 600, 850) # 1200, 800
32 |
33 |
34 | def animate(i, dfs, dfl, inputs, digits):
35 | if i < len(dfl):
36 | axl.clear()
37 | axl.set_ylim(top=200, bottom=60)
38 | axl.plot(dfl.iloc[0:i, 0].values, dfl.iloc[0:i, 1].values)
39 | axl.set_ylabel('Loss')
40 | axl.set_xlabel('Epochs')
41 | axl.set_title('Training Progress')
42 |
43 | s = f'Loss: {dfl.iloc[max(0, i - 9):i + 1, 1].mean():.2f}'
44 | axl.text(0.85, 0.96, s, horizontalalignment='left',
45 | verticalalignment='top', transform=axl.transAxes)
46 | s = f'Epoch: {dfl.iloc[i, 0]:.1f}'
47 | axl.text(0.85, 0.88, s, horizontalalignment='left',
48 | verticalalignment='top', transform=axl.transAxes)
49 |
50 | data = dfs[dfs['epoch'] == dfl['epoch'].iloc[i]]
51 | if len(data) > 0:
52 | for j, k in enumerate([0, 18, 4]): # index 15 also good 3-5 confusion
53 | img = data.iloc[k, :784].values.reshape(28, 28)
54 | inp = inputs[k]
55 | img[inp != -1] = inp[inp != -1]
56 |
57 | img = cv2.resize(img, dsize=(280, 280),
58 | interpolation=cv2.INTER_NEAREST)
59 | img = np.stack((img,)*3, axis=-1)
60 | img[140:, 140, 1] = 1
61 | img[140, :140, 1] = 1
62 |
63 | axd[j].clear()
64 | axd[j].imshow(img, cmap='gray')
65 | axd[j].set_title('Sample %d' % (digits[k]))
66 | axd[j].get_xaxis().set_visible(False)
67 | axd[j].get_yaxis().set_visible(False)
68 |
69 |
70 | def main(args):
71 | dfs = pd.read_csv('../data/samples.csv')
72 | dfl = pd.read_csv('../data/losses.csv')
73 |
74 | inputs, digits = get_val_images(num_quadrant_inputs=1,
75 | num_images=30, shuffle=False)
76 | inputs = inputs.numpy()
77 |
78 | Writer = animation.writers['ffmpeg']
79 | writer = Writer(fps=15, metadata=dict(artist='Carlos Souza'), bitrate=1800)
80 |
81 | ani = animation.FuncAnimation(fig, animate, interval=50, frames=100000,
82 | fargs=(dfs, dfl, inputs, digits, ))
83 |
84 | if args.show:
85 | plt.show()
86 | else:
87 | ani.save('animation.mp4', writer=writer)
88 |
89 |
90 | if __name__ == '__main__':
91 | parser = argparse.ArgumentParser(description='Generate animation.')
92 | parser.add_argument('-s', '--show', action="store_true", default=False,
93 | help='Use this flag to show video animation on screen '
94 | 'instead of saving it to file. Default is to save '
95 | 'to "animation.mp4".')
96 | args = parser.parse_args()
97 | main(args)
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
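Each frame in `animate` upscales a 28×28 prediction to 280×280 with nearest-neighbour interpolation, converts it to RGB, and draws green lines along the top and right edges of the bottom-left quadrant (the input quadrant when `num_quadrant_inputs=1`). The sketch below does the same with NumPy only. Replacing `cv2.resize(..., interpolation=cv2.INTER_NEAREST)` with `np.repeat` is an assumption that should hold for an exact integer scale factor; `upscale_with_outline` is a hypothetical helper, not part of `video.py`.

```python
import numpy as np


def upscale_with_outline(img, factor=10):
    """Nearest-neighbour upscale of a 2-D image plus a green outline of its
    bottom-left quadrant, mimicking the per-frame drawing in video.py."""
    # repeat each pixel `factor` times along both axes (integer-factor upscale)
    big = np.repeat(np.repeat(img, factor, axis=0), factor, axis=1)
    rgb = np.stack((big,) * 3, axis=-1)   # gray -> RGB
    mid = big.shape[0] // 2               # 140 for a 28x28 input and factor 10
    rgb[mid:, mid, 1] = 1                 # right edge of the bottom-left quadrant
    rgb[mid, :mid, 1] = 1                 # top edge of the bottom-left quadrant
    return rgb
```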
--------------------------------------------------------------------------------