├── .gitignore
├── LICENSE
├── README.md
├── _readme_figs
│   ├── celeba_sample_1.png
│   ├── celeba_sample_1_downsc.png
│   ├── cifar_sample_1.png
│   ├── eq_layer_inspection.png
│   ├── eqns.txt
│   ├── layers_celeba
│   │   ├── sample_mode_layer10.png
│   │   ├── sample_mode_layer13.png
│   │   ├── sample_mode_layer15.png
│   │   ├── sample_mode_layer19.png
│   │   └── sample_mode_layer5.png
│   ├── layers_cifar
│   │   ├── sample_mode_layer14.png
│   │   ├── sample_mode_layer2.png
│   │   ├── sample_mode_layer6.png
│   │   └── sample_mode_layer9.png
│   ├── layers_mnist
│   │   ├── sample_mode_layer11.png
│   │   ├── sample_mode_layer3.png
│   │   ├── sample_mode_layer7.png
│   │   └── sample_mode_layer9.png
│   ├── layers_multidsprites
│   │   ├── sample_mode_layer10.png
│   │   ├── sample_mode_layer11.png
│   │   ├── sample_mode_layer2.png
│   │   ├── sample_mode_layer6.png
│   │   └── sample_mode_layer9.png
│   ├── layers_svhn
│   │   ├── sample_mode_layer12.png
│   │   ├── sample_mode_layer14.png
│   │   ├── sample_mode_layer3.png
│   │   └── sample_mode_layer9.png
│   ├── lvae_eq.png
│   ├── mnist_sample_1.png
│   ├── multidsprites_sample_1_downsc.png
│   └── svhn_sample_1.png
├── data
│   ├── multi-dsprites-binary-rgb
│   │   └── multi_dsprites_color_012.npz
│   └── multi_mnist
│       └── multi_binary_mnist_012.npz
├── evaluate.py
├── experiment
│   ├── __init__.py
│   ├── data.py
│   └── experiment_manager.py
├── lib
│   ├── datasets.py
│   ├── likelihoods.py
│   ├── nn.py
│   └── stochastic.py
├── main.py
├── models
│   ├── __init__.py
│   ├── lvae.py
│   └── lvae_layers.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .DS_Store
3 | __pycache__/
4 | checkpoints/
5 | data/
6 | evaluation_results/
7 | results/
8 | tensorboard_logs/
9 | *.pyc
10 |
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Andrea Dittadi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ladder Variational Autoencoders (LVAE)
2 |
3 | PyTorch implementation of Ladder Variational Autoencoders (LVAE) [1]:
4 |
5 |
6 | ![LVAE equations](_readme_figs/lvae_eq.png)
7 |
8 | where the variational distributions _q_ at each layer are multivariate Normal
9 | with diagonal covariance.
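
For text-only readers, the factorization shown in the figure (LaTeX source in
`_readme_figs/eqns.txt`) reads, in plain text:

```
p(x, z)  = p(z) p(x | z)
p(z)     = p(z_L) prod_{i=1}^{L-1} p(z_i | z_{>i})
q(z | x) = q(z_L | x) prod_{i=1}^{L-1} q(z_i | z_{>i}, x)
```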
10 |
11 | **Significant differences from [1]** include:
12 | - skip connections in the generative path: conditioning on _all_ layers above
13 | rather than only on _the_ layer above (see for example [2])
14 | - spatial (convolutional) latent variables
15 | - free bits [3] instead of beta annealing [4]
16 |
17 | ### Install requirements and run MNIST example
18 |
19 | ```
20 | pip install -r requirements.txt
21 | CUDA_VISIBLE_DEVICES=0 python main.py --zdims 32 32 32 --downsample 1 1 1 --nonlin elu --skip --blocks-per-layer 4 --gated --freebits 0.5 --learn-top-prior --data-dep-init --seed 42 --dataset static_mnist
22 | ```
23 |
24 | Dependencies include [boilr](https://github.com/addtt/boiler-pytorch) (a framework
25 | for PyTorch) and [multiobject](https://github.com/addtt/multi-object-datasets)
26 | (which provides multi-object datasets with PyTorch dataloaders).
27 |
28 |
29 |
30 | ## Likelihood results
31 |
32 | Log likelihood bounds on the test set (average over 4 random seeds).
33 |
34 | | dataset              | num layers | -ELBO        | - log _p(x)_ ≤ [100 iws] | - log _p(x)_ ≤ [1000 iws] |
35 | | -------------------- |:----------:|:------------:|:------------------------:|:-------------------------:|
36 | | binarized MNIST | 3 | 82.14 | 79.47 | 79.24 |
37 | | binarized MNIST | 6 | 80.74 | 78.65 | 78.52 |
38 | | binarized MNIST | 12 | 80.50 | 78.50 | 78.30 |
39 | | multi-dSprites (0-2) | 12 | 26.9 | 23.2 | |
40 | | SVHN | 15 | 4012 (1.88) | 3973 (1.87) | |
41 | | CIFAR10 | 3 | 7651 (3.59) | 7591 (3.56) | |
42 | | CIFAR10 | 6 | 7321 (3.44) | 7268 (3.41) | |
43 | | CIFAR10 | 15 | 7128 (3.35) | 7068 (3.32) | |
44 | | CelebA | 20 | 20026 (2.35) | 19913 (2.34) | |
45 |
46 | Note:
47 | - Bits per dimension in brackets: nats divided by (number of dimensions × ln 2); e.g. for CIFAR10, 7651 / (3072 · ln 2) ≈ 3.59.
48 | - 'iws' stands for importance-weighted samples. More samples give a tighter
49 | lower bound on the log likelihood, which converges to the true log likelihood
50 | as the number of samples goes to infinity [5]. Note that the model is always
51 | trained with the ELBO (1 sample).
52 | - Each pixel in the images is modeled independently. The likelihood is Bernoulli
53 | for binary images, and discretized mixture of logistics with 10
54 | components [6] otherwise.
55 | - One day I'll get around to evaluating the IW bound on all datasets with 10000 samples.
56 |
57 |
58 | ## Supported datasets
59 |
60 | - Statically binarized MNIST [7], see Hugo Larochelle's website `http://www.cs.toronto.edu/~larocheh/public/datasets/`
61 | - [SVHN](http://ufldl.stanford.edu/housenumbers/)
62 | - [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)
63 | - [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), center-cropped and rescaled to 64x64 – see code for details. The dataset path in `experiment.data.DatasetLoader` has to be modified.
64 | - [binary multi-dSprites](https://github.com/addtt/multi-object-datasets): 64x64 RGB shapes (0 to 2) in each image
65 |
66 |
67 | ## Samples
68 |
69 | #### Binarized MNIST
70 |
71 | ![MNIST samples](_readme_figs/mnist_sample_1.png)
72 |
73 | #### Multi-dSprites
74 |
75 | ![multi-dSprites samples](_readme_figs/multidsprites_sample_1_downsc.png)
76 |
77 | #### SVHN
78 |
79 | ![SVHN samples](_readme_figs/svhn_sample_1.png)
80 |
81 | #### CIFAR
82 |
83 | ![CIFAR samples](_readme_figs/cifar_sample_1.png)
84 |
85 | #### CelebA
86 |
87 | ![CelebA samples](_readme_figs/celeba_sample_1_downsc.png)
88 |
89 |
90 | ## Hierarchical representations
91 |
92 | Here we try to visualize the representations learned by individual layers.
93 | We can get a rough idea of what's going on at layer _i_ as follows:
94 |
95 | - Sample latent variables from all layers above layer _i_ (Eq. 1).
96 |
97 | - With these variables fixed, take _S_ conditional samples at layer _i_ (Eq. 2). Note
98 | that they are all conditioned on the same samples. These correspond to one
99 | row in the images below.
100 |
101 | - For each of these samples (each small image in the images below), pick the
102 | mode/mean of the conditional distribution of each layer below (Eq. 3).
103 |
104 | - Finally, sample an image ***x*** given the latent variables (Eq. 4).
105 |
106 | Formally:
107 |
108 |
109 | ![layer inspection equations](_readme_figs/eq_layer_inspection.png)
110 |
111 | where _s_ = 1, ..., _S_ denotes the sample index.
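
In code, this procedure maps onto the sampling options of the model; below is a
condensed sketch of `inspect_layer_repr()` from `evaluate.py`, where `model` is
a trained `LadderVAE`, `i` is the inspected layer, and `S` is the number of
samples per row:

```python
mode_layers = range(i)                          # layers j < i: conditional mode (Eq. 3)
constant_layers = range(i + 1, model.n_layers)  # layers j > i: sampled once, shared (Eq. 1)
# One call yields one row of S images (Eqs. 2 and 4).
row = model.sample_prior(S,
                         mode_layers=mode_layers,
                         constant_layers=constant_layers)
```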
112 |
113 | The equations above yield _S_ sample images conditioned on the same values of
114 | ***z*** for layers _i_+1 to _L_. These _S_ samples are shown in one row of the
115 | images below.
116 | Notice that samples within each row are almost identical when the variability
117 | comes from a low-level layer, as such layers mostly model local structure and details.
118 | Higher layers on the other hand model global structure, and we observe more and
119 | more variability in each row as we move to higher layers. When the sampling
120 | happens in the top layer (_i = L_), all samples are completely independent,
121 | even within a row.
122 |
123 | #### Binarized MNIST: layers 4, 8, 10, and 12 (top layer)
124 |
125 | ![layer 4](_readme_figs/layers_mnist/sample_mode_layer3.png)
126 | ![layer 8](_readme_figs/layers_mnist/sample_mode_layer7.png)
127 |
128 | ![layer 10](_readme_figs/layers_mnist/sample_mode_layer9.png)
129 | ![layer 12](_readme_figs/layers_mnist/sample_mode_layer11.png)
130 |
131 |
132 | #### SVHN: layers 4, 10, 13, and 15 (top layer)
133 |
134 | ![layer 4](_readme_figs/layers_svhn/sample_mode_layer3.png)
135 | ![layer 10](_readme_figs/layers_svhn/sample_mode_layer9.png)
136 |
137 | ![layer 13](_readme_figs/layers_svhn/sample_mode_layer12.png)
138 | ![layer 15](_readme_figs/layers_svhn/sample_mode_layer14.png)
139 |
140 |
141 | #### CIFAR: layers 3, 7, 10, and 15 (top layer)
142 |
143 | ![layer 3](_readme_figs/layers_cifar/sample_mode_layer2.png)
144 | ![layer 7](_readme_figs/layers_cifar/sample_mode_layer6.png)
145 |
146 | ![layer 10](_readme_figs/layers_cifar/sample_mode_layer9.png)
147 | ![layer 15](_readme_figs/layers_cifar/sample_mode_layer14.png)
148 |
149 |
150 | #### CelebA: layers 6, 11, 16, and 20 (top layer)
151 |
152 | ![layer 6](_readme_figs/layers_celeba/sample_mode_layer5.png)
153 |
154 | ![layer 11](_readme_figs/layers_celeba/sample_mode_layer10.png)
155 |
156 | ![layer 16](_readme_figs/layers_celeba/sample_mode_layer15.png)
157 |
158 | ![layer 20](_readme_figs/layers_celeba/sample_mode_layer19.png)
159 |
160 |
161 | #### Multi-dSprites: layers 3, 7, 10, and 12 (top layer)
162 |
163 | ![layer 3](_readme_figs/layers_multidsprites/sample_mode_layer2.png)
164 | ![layer 7](_readme_figs/layers_multidsprites/sample_mode_layer6.png)
165 |
166 | ![layer 10](_readme_figs/layers_multidsprites/sample_mode_layer9.png)
167 | ![layer 12](_readme_figs/layers_multidsprites/sample_mode_layer11.png)
168 |
169 |
170 |
171 | ## Experimental details
172 |
173 | I did not perform an extensive hyperparameter search, but this worked pretty well:
174 |
175 | - Downsampling by a factor of 2 at the beginning of inference.
176 | After that, activations are downsampled 4 times for 64x64 images (CelebA and
177 | multi-dSprites), and 3 times otherwise. The spatial size of the final feature
178 | map is always 2x2.
179 | Between these downsampling steps there is approximately the same number of
180 | stochastic layers.
181 | - 4 residual blocks between stochastic layers. Haven't tried with more
182 | than 4 though, as models become quite big and we get diminishing returns.
183 | - The deterministic parts of the bottom-up and top-down architectures are
184 | (almost) perfectly mirrored for simplicity.
185 | - Stochastic layers have spatial random variables, and the number of rvs per
186 | "location" (i.e. number of channels of the feature map after sampling from a
187 | layer) is 32 in all layers.
188 | - All other feature maps in deterministic paths have 64 channels.
189 | - Skip connections in the generative model (`--skip`).
190 | - Gated residual blocks (`--gated`).
191 | - Learned prior of the top layer (`--learn-top-prior`).
192 | - A form of data-dependent initialization of weights (`--data-dep-init`).
193 | See code for details.
194 | - freebits=1.0 in experiments with more than 6 stochastic layers, and 0.5 for
195 | smaller models.
196 | - For everything else, see `_add_args()` in `experiment/experiment_manager.py`.
197 |
198 | With these settings, the number of parameters is roughly 1M per stochastic
199 | layer. I tried to control for this by experimenting e.g. with half the number
200 | of layers but twice the number of residual blocks, but it looks like the number
201 | of stochastic layers is what matters the most.
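
As an example, a 12-layer model on binarized MNIST along these lines could be
launched as follows (the `--zdims`/`--downsample` spec is only illustrative;
all flags are defined in `_add_args()` in `experiment/experiment_manager.py`):

```
CUDA_VISIBLE_DEVICES=0 python main.py --zdims 32 32 32 32 32 32 32 32 32 32 32 32 --downsample 1 0 0 0 1 0 0 0 1 0 0 0 --nonlin elu --skip --blocks-per-layer 4 --gated --freebits 1.0 --learn-top-prior --data-dep-init --dataset static_mnist
```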
202 |
203 |
204 | ## References
205 |
206 | [1] CK Sønderby,
207 | T Raiko,
208 | L Maaløe,
209 | SK Sønderby,
210 | O Winther.
211 | _Ladder Variational Autoencoders_, NIPS 2016
212 |
213 | [2] L Maaløe, M Fraccaro, V Liévin, O Winther.
214 | _BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling_,
215 | NeurIPS 2019
216 |
217 | [3] DP Kingma,
218 | T Salimans,
219 | R Jozefowicz,
220 | X Chen,
221 | I Sutskever,
222 | M Welling.
223 | _Improved Variational Inference with Inverse Autoregressive Flow_,
224 | NIPS 2016
225 |
226 | [4] I Higgins, L Matthey, A Pal, C Burgess, X Glorot, M Botvinick,
227 | S Mohamed, A Lerchner.
228 | _beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework_,
229 | ICLR 2017
230 |
231 | [5] Y Burda, RB Grosse, R Salakhutdinov.
232 | _Importance Weighted Autoencoders_,
233 | ICLR 2016
234 |
235 | [6] T Salimans, A Karpathy, X Chen, DP Kingma.
236 | _PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications_,
237 | ICLR 2017
238 |
239 | [7] H Larochelle, I Murray.
240 | _The Neural Autoregressive Distribution Estimator_,
241 | AISTATS 2011
242 |
--------------------------------------------------------------------------------
/_readme_figs/celeba_sample_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/celeba_sample_1.png
--------------------------------------------------------------------------------
/_readme_figs/celeba_sample_1_downsc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/celeba_sample_1_downsc.png
--------------------------------------------------------------------------------
/_readme_figs/cifar_sample_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/cifar_sample_1.png
--------------------------------------------------------------------------------
/_readme_figs/eq_layer_inspection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/eq_layer_inspection.png
--------------------------------------------------------------------------------
/_readme_figs/eqns.txt:
--------------------------------------------------------------------------------
1 | \begin{align*}
2 | p_{\theta}(\mathbf{x}, \mathbf{z}) &= p_{\theta}(\mathbf{z}) \ p_{\theta}(\mathbf{x} \;|\; \mathbf{z})\\[.7em]
3 | p_{\theta}(\mathbf{z}) &=p_{\theta}\left(\mathbf{z}_{L}\right) \ \prod_{i=1}^{L-1} p_{\theta}\left(\mathbf{z}_{i} \;|\; \mathbf{z}_{> i}\right) \\
4 | q_{\phi, \theta}(\mathbf{z} \;|\; \mathbf{x}) &=q_{\phi, \theta}\left(\mathbf{z}_{L} \;|\; \mathbf{x}\right) \ \prod_{i=1}^{L-1} q_{\phi, \theta}\left(\mathbf{z}_{i} \;|\; \mathbf{z}_{>i}\;, \mathbf{x} \right)
5 | \end{align*}
6 |
7 | -----
8 |
9 | \newcommand{\z}{\mathbf{z}}
10 | \newcommand{\x}{\mathbf{x}}
11 | \newcommand{\argmax}[1]{\underset{#1}{\operatorname{arg}\operatorname{max}}\;}
12 |
13 | \begin{align}
14 | \z_{>i} &\sim p(\z_{>i}) = p(\z_L) \prod_{j=i+1}^{L-1} p(\z_j \;|\; \z_{>j})\\
15 | \z_i^{(s)} &\sim p(\z_i \;|\; \z_{>i})\\[1em]
16 | % \z_{j}^{(s)} &= \mathbb{E}_{p(\z_{j} \;|\; \z_{j+1:i}^{(s)}, \z_{>i})} \left[ \z_{j} \right], \qquad \mbox{for } j<i \\
17 | \z_{j}^{(s)} &= \argmax{\z_{j}}\; p(\z_{j} \;|\; \z_{j+1:i}^{(s)}, \z_{>i}) \qquad \mbox{for } j<i \\[1em]
18 | \x^{(s)} &\sim p(\x \;|\; \z_{1:i}^{(s)}, \z_{>i})
19 | \end{align}
20 |
--------------------------------------------------------------------------------
/_readme_figs/layers_celeba/sample_mode_layer10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer10.png
--------------------------------------------------------------------------------
/_readme_figs/layers_celeba/sample_mode_layer13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer13.png
--------------------------------------------------------------------------------
/_readme_figs/layers_celeba/sample_mode_layer15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer15.png
--------------------------------------------------------------------------------
/_readme_figs/layers_celeba/sample_mode_layer19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer19.png
--------------------------------------------------------------------------------
/_readme_figs/layers_celeba/sample_mode_layer5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer5.png
--------------------------------------------------------------------------------
/_readme_figs/layers_cifar/sample_mode_layer14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer14.png
--------------------------------------------------------------------------------
/_readme_figs/layers_cifar/sample_mode_layer2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer2.png
--------------------------------------------------------------------------------
/_readme_figs/layers_cifar/sample_mode_layer6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer6.png
--------------------------------------------------------------------------------
/_readme_figs/layers_cifar/sample_mode_layer9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer9.png
--------------------------------------------------------------------------------
/_readme_figs/layers_mnist/sample_mode_layer11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer11.png
--------------------------------------------------------------------------------
/_readme_figs/layers_mnist/sample_mode_layer3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer3.png
--------------------------------------------------------------------------------
/_readme_figs/layers_mnist/sample_mode_layer7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer7.png
--------------------------------------------------------------------------------
/_readme_figs/layers_mnist/sample_mode_layer9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer9.png
--------------------------------------------------------------------------------
/_readme_figs/layers_multidsprites/sample_mode_layer10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer10.png
--------------------------------------------------------------------------------
/_readme_figs/layers_multidsprites/sample_mode_layer11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer11.png
--------------------------------------------------------------------------------
/_readme_figs/layers_multidsprites/sample_mode_layer2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer2.png
--------------------------------------------------------------------------------
/_readme_figs/layers_multidsprites/sample_mode_layer6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer6.png
--------------------------------------------------------------------------------
/_readme_figs/layers_multidsprites/sample_mode_layer9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer9.png
--------------------------------------------------------------------------------
/_readme_figs/layers_svhn/sample_mode_layer12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer12.png
--------------------------------------------------------------------------------
/_readme_figs/layers_svhn/sample_mode_layer14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer14.png
--------------------------------------------------------------------------------
/_readme_figs/layers_svhn/sample_mode_layer3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer3.png
--------------------------------------------------------------------------------
/_readme_figs/layers_svhn/sample_mode_layer9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer9.png
--------------------------------------------------------------------------------
/_readme_figs/lvae_eq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/lvae_eq.png
--------------------------------------------------------------------------------
/_readme_figs/mnist_sample_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/mnist_sample_1.png
--------------------------------------------------------------------------------
/_readme_figs/multidsprites_sample_1_downsc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/multidsprites_sample_1_downsc.png
--------------------------------------------------------------------------------
/_readme_figs/svhn_sample_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/svhn_sample_1.png
--------------------------------------------------------------------------------
/data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz
--------------------------------------------------------------------------------
/data/multi_mnist/multi_binary_mnist_012.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/data/multi_mnist/multi_binary_mnist_012.npz
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Standalone script for a couple of simple evaluations/tests of trained models.
3 | """
4 |
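# Typical invocation (checkpoint loading/selection is handled by boilr's
# BaseOfflineEvaluator; the remaining flags are defined in Evaluator._add_args):
#   python evaluate.py --ll --ll-samples 100 --ps 2 --layer-repr
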
5 | import argparse
6 | import os
7 | import warnings
8 |
9 | import torch
10 | import torch.utils.data
11 | from boilr.eval import BaseOfflineEvaluator
12 | from boilr.utils.viz import img_grid_pad_value
13 | from torchvision.utils import save_image
14 |
15 | from experiment.experiment_manager import LVAEExperiment
16 |
17 |
18 | class Evaluator(BaseOfflineEvaluator):
19 |
20 | def run(self):
21 |
22 | torch.set_grad_enabled(False)
23 |
24 | n = 12
25 |
26 | e = self._experiment
27 | e.model.eval()
28 |
29 | # Run evaluation and print results
30 | results = e.test_procedure(iw_samples=self.args.ll_samples)
31 | print("Eval results:\n{}".format(results))
32 |
33 | # Save samples
34 | for i in range(self.args.prior_samples):
35 | fname = os.path.join(self._img_folder, "samples_{}.png".format(i))
36 | e.generate_and_save_samples(fname, nrows=n)
37 |
38 | # Save input and reconstructions
39 | x, y = next(iter(e.dataloaders.test))
40 | fname = os.path.join(self._img_folder, "reconstructions.png")
41 | e.generate_and_save_reconstructions(x, fname, nrows=n)
42 |
43 | # Inspect representations learned by each layer
44 | if self.args.inspect_layer_repr:
45 | inspect_layer_repr(e.model, self._img_folder, n=n)
46 |
47 | # @classmethod
48 | # def _define_args_defaults(cls) -> dict:
49 | # defaults = super(Evaluator, cls)._define_args_defaults()
50 | # return defaults
51 |
52 | def _add_args(self, parser: argparse.ArgumentParser) -> None:
53 |
54 | super(Evaluator, self)._add_args(parser)
55 |
56 | parser.add_argument('--ll',
57 | action='store_true',
58 | help="estimate log likelihood with importance-"
59 | "weighted bound")
60 | parser.add_argument('--ll-samples',
61 | type=int,
62 | default=100,
63 | dest='ll_samples',
64 | metavar='N',
65 | help="number of importance-weighted samples for "
66 | "log likelihood estimation")
67 | parser.add_argument('--ps',
68 | type=int,
69 | default=1,
70 | dest='prior_samples',
71 | metavar='N',
72 | help="number of batches of samples from prior")
73 | parser.add_argument('--layer-repr',
74 | action='store_true',
75 | dest='inspect_layer_repr',
76 | help='inspect layer representations. Generate '
77 | 'samples by sampling top layers once, then taking '
78 | 'many samples from a middle layer, and finally '
79 | 'sample the downstream layers from the conditional '
80 | 'mode. Do this for every layer.')
81 |
82 | @classmethod
83 | def _check_args(cls, args: argparse.Namespace) -> argparse.Namespace:
84 | args = super(Evaluator, cls)._check_args(args)
85 |
86 | if not args.ll:
87 | args.ll_samples = 1
88 | if args.load_step is not None:
89 | warnings.warn(
90 | "Loading weights from specific training step is not supported "
91 | "for now. The model will be loaded from the last checkpoint.")
92 | return args
93 |
94 |
95 | def inspect_layer_repr(model, img_folder, n=8):
96 | for i in range(model.n_layers):
97 |
98 | # print('layer', i)
99 |
100 | mode_layers = range(i)
101 | constant_layers = range(i + 1, model.n_layers)
102 |
103 | # Sample top layers once, then take many samples of a middle layer,
104 | # then sample from the mode in all downstream layers.
105 | sample = []
106 | for r in range(n):
107 | sample.append(
108 | model.sample_prior(n,
109 | mode_layers=mode_layers,
110 | constant_layers=constant_layers))
111 | sample = torch.cat(sample)
112 | pad_value = img_grid_pad_value(sample)
113 | fname = os.path.join(img_folder, 'sample_mode_layer' + str(i) + '.png')
114 | save_image(sample, fname, nrow=n, pad_value=pad_value)
115 |
116 |
117 | def main():
118 | evaluator = Evaluator(experiment_class=LVAEExperiment)
119 | evaluator()
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
/experiment/__init__.py:
--------------------------------------------------------------------------------
1 | from .experiment_manager import LVAEExperiment
2 |
--------------------------------------------------------------------------------
/experiment/data.py:
--------------------------------------------------------------------------------
1 | from multiobject.pytorch import MultiObjectDataset, MultiObjectDataLoader
2 | from torch.utils.data import DataLoader
3 | from torchvision import transforms
4 | from torchvision.datasets import CIFAR10, SVHN, CelebA
5 |
6 | from lib.datasets import StaticBinaryMnist
7 |
8 | multiobject_paths = {
9 | 'multi_mnist_binary':
10 | './data/multi_mnist/multi_binary_mnist_012.npz',
11 | 'multi_dsprites_binary_rgb':
12 | './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz',
13 | }
14 | multiobject_datasets = multiobject_paths.keys()
15 |
16 |
17 | class DatasetLoader:
18 | """
19 | Wrapper for DataLoaders. Data attributes:
20 | - train: DataLoader object for training set
21 | - test: DataLoader object for test set
22 | - data_shape: shape of each data point (channels, height, width)
23 | - img_size: spatial dimensions of each data point (height, width)
24 | - color_ch: number of color channels
25 | """
26 |
27 | def __init__(self, args, cuda):
28 |
29 | kwargs = {'num_workers': 1, 'pin_memory': False} if cuda else {}
30 |
31 | # Default dataloader class
32 | dataloader_class = DataLoader
33 |
34 | if args.dataset_name == 'static_mnist':
35 | data_folder = './data/static_bin_mnist/'
36 | train_set = StaticBinaryMnist(data_folder,
37 | train=True,
38 | download=True,
39 | shuffle_init=True)
40 | test_set = StaticBinaryMnist(data_folder,
41 | train=False,
42 | download=True,
43 | shuffle_init=True)
44 |
45 | elif args.dataset_name == 'cifar10':
46 | # Discrete values 0, 1/255, ..., 254/255, 1
47 | transform = transforms.Compose([
48 | # Move values to the center of 256 bins
49 | # transforms.Lambda(lambda x: Image.eval(
50 | # x, lambda y: y * (255/256) + 1/512)),
51 | transforms.ToTensor(),
52 | ])
53 | data_folder = './data/cifar10/'
54 | train_set = CIFAR10(data_folder,
55 | train=True,
56 | download=True,
57 | transform=transform)
58 | test_set = CIFAR10(data_folder,
59 | train=False,
60 | download=True,
61 | transform=transform)
62 |
63 | elif args.dataset_name == 'svhn':
64 | transform = transforms.ToTensor()
65 | data_folder = './data/svhn/'
66 | train_set = SVHN(data_folder,
67 | split='train',
68 | download=True,
69 | transform=transform)
70 | test_set = SVHN(data_folder,
71 | split='test',
72 | download=True,
73 | transform=transform)
74 |
75 | elif args.dataset_name == 'celeba':
76 | transform = transforms.Compose([
77 | transforms.CenterCrop(148),
78 | transforms.Resize((64, 64)),
79 | transforms.ToTensor(),
80 | ])
81 | data_folder = '/scratch/adit/data/celeba/'
82 | train_set = CelebA(data_folder,
83 | split='train',
84 | download=True,
85 | transform=transform)
86 | test_set = CelebA(data_folder,
87 | split='valid',
88 | download=True,
89 | transform=transform)
90 |
91 | elif args.dataset_name in multiobject_datasets:
92 | data_path = multiobject_paths[args.dataset_name]
93 | train_set = MultiObjectDataset(data_path, train=True)
94 | test_set = MultiObjectDataset(data_path, train=False)
95 |
96 | # Custom data loader class
97 | dataloader_class = MultiObjectDataLoader
98 |
99 | else:
100 | raise RuntimeError("Unrecognized data set '{}'".format(
101 | args.dataset_name))
102 |
103 | self.train = dataloader_class(train_set,
104 | batch_size=args.batch_size,
105 | shuffle=True,
106 | drop_last=True,
107 | **kwargs)
108 | self.test = dataloader_class(test_set,
109 | batch_size=args.test_batch_size,
110 | shuffle=False,
111 | **kwargs)
112 |
113 | self.data_shape = self.train.dataset[0][0].size()
114 | self.img_size = self.data_shape[1:]
115 | self.color_ch = self.data_shape[0]
116 |
--------------------------------------------------------------------------------
/experiment/experiment_manager.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import boilr.data
4 | import torch
5 | from boilr import VAEExperimentManager
6 | from boilr.nn.init import data_dependent_init
7 | from boilr.utils import linear_anneal
8 | from torch import optim
9 | from torch.optim.optimizer import Optimizer
10 | from typing import Optional
11 | from models.lvae import LadderVAE
12 | from .data import DatasetLoader
13 |
14 | boilr.set_options(model_print_depth=2)
15 |
16 |
17 | class LVAEExperiment(VAEExperimentManager):
18 | """
19 | Experiment manager.
20 |
21 | Data attributes:
22 | - 'args': argparse.Namespace containing all config parameters. When
23 | initializing this object, if 'args' is not given, all config
24 | parameters are set based on experiment defaults and user input, using
25 | argparse.
26 | - 'run_description': string description of the run that includes a timestamp
27 | and can be used e.g. as folder name for logging.
28 | - 'model': the model.
29 | - 'device': torch.device that is being used
30 | - 'dataloaders': DataLoaders, with attributes 'train' and 'test'
31 | - 'optimizer': the optimizer
32 | """
33 |
34 | def _make_datamanager(self) -> boilr.data.BaseDatasetManager:
35 | cuda = self.device.type == 'cuda'
36 | return DatasetLoader(self.args, cuda)
37 |
38 | def _make_model(self) -> torch.nn.Module:
39 | args = self.args
40 | model = LadderVAE(
41 | self.dataloaders.color_ch,
42 | z_dims=args.z_dims,
43 | blocks_per_layer=args.blocks_per_layer,
44 | downsample=args.downsample,
45 | merge_type=args.merge_layers,
46 | batchnorm=args.batch_norm,
47 | nonlin=args.nonlin,
48 | stochastic_skip=args.skip_connections,
49 | n_filters=args.n_filters,
50 | dropout=args.dropout,
51 | res_block_type=args.residual_type,
52 | free_bits=args.free_bits,
53 | learn_top_prior=args.learn_top_prior,
54 | img_shape=self.dataloaders.img_size,
55 | likelihood_form=args.likelihood,
56 | gated=args.gated,
57 | no_initial_downscaling=args.no_initial_downscaling,
58 | analytical_kl=args.analytical_kl,
59 | ).to(self.device)
60 |
61 | # Weight initialization
62 | if args.simple_data_dependent_init:
63 |
64 | # Get batch
65 | t = [
66 | self.dataloaders.train.dataset[i]
67 | for i in range(args.batch_size)
68 | ]
69 | t = torch.stack(tuple(t[i][0] for i in range(len(t))))
70 |
71 | # Use batch for data dependent init
72 | data_dependent_init(model, {'x': t.to(self.device)})
73 |
74 | return model
75 |
76 | def _make_optimizer(self) -> Optimizer:
77 | args = self.args
78 | optimizer = optim.Adamax(self.model.parameters(),
79 | lr=args.lr,
80 | weight_decay=args.weight_decay)
81 | return optimizer
82 |
83 | @classmethod
84 | def _define_args_defaults(cls) -> dict:
85 | defaults = super(LVAEExperiment, cls)._define_args_defaults()
86 |
87 | # Override boilr defaults
88 | defaults.update(
89 |
90 | # General
91 | batch_size=64,
92 | test_batch_size=1000,
93 | lr=3e-4,
94 | train_log_every=10000,
95 | test_log_every=10000,
96 | checkpoint_every=100000,
97 | keep_checkpoint_max=2,
98 | resume="",
99 |
100 | # VI-specific
101 | loglikelihood_every=50000,
102 | loglikelihood_samples=100,
103 | )
104 |
105 | return defaults
106 |
107 | def _add_args(self, parser: argparse.ArgumentParser) -> None:
108 |
109 | super(LVAEExperiment, self)._add_args(parser)
110 |
111 | def list_options(lst):
112 | if lst:
113 | return "'" + "' | '".join(lst) + "'"
114 | return ""
115 |
116 | legal_merge_layers = ['linear', 'residual']
117 | legal_nonlin = ['relu', 'leakyrelu', 'elu', 'selu']
118 | legal_resblock = ['cabdcabd', 'bacdbac', 'bacdbacd']
119 | legal_datasets = [
120 | 'static_mnist', 'cifar10', 'celeba', 'svhn',
121 | 'multi_dsprites_binary_rgb', 'multi_mnist_binary'
122 | ]
123 | legal_likelihoods = [
124 | 'bernoulli', 'gaussian', 'discr_log', 'discr_log_mix'
125 | ]
126 |
127 | parser.add_argument('-d',
128 | '--dataset',
129 | type=str,
130 | choices=legal_datasets,
131 | default='static_mnist',
132 | metavar='NAME',
133 | dest='dataset_name',
134 | help="dataset: " + list_options(legal_datasets))
135 |
136 | parser.add_argument('--likelihood',
137 | type=str,
138 | choices=legal_likelihoods,
139 | metavar='NAME',
140 | dest='likelihood',
141 | help="likelihood: {}; the default depends on the "
142 | "dataset".format(list_options(legal_likelihoods)))
143 |
144 | parser.add_argument('--zdims',
145 | nargs='+',
146 | type=int,
147 | default=[32, 32, 32],
148 | metavar='DIM',
149 | dest='z_dims',
150 | help='list of dimensions (number of channels) for '
151 | 'each stochastic layer')
152 |
153 | parser.add_argument('--blocks-per-layer',
154 | type=int,
155 | default=2,
156 | metavar='N',
157 | help='residual blocks between stochastic layers')
158 |
159 | parser.add_argument('--nfilters',
160 | type=int,
161 | default=64,
162 | metavar='N',
163 | dest='n_filters',
164 | help='number of channels in all residual blocks')
165 |
166 | parser.add_argument('--no-bn',
167 | action='store_true',
168 | dest='no_batch_norm',
169 | help='do not use batch normalization')
170 |
171 | parser.add_argument('--skip',
172 | action='store_true',
173 | dest='skip_connections',
174 | help='skip connections in generative model')
175 |
176 | parser.add_argument('--gated',
177 | action='store_true',
178 | dest='gated',
179 | help='use gated layers in residual blocks')
180 |
181 | parser.add_argument('--downsample',
182 | nargs='+',
183 | type=int,
184 | default=[1, 1, 1],
185 | metavar='N',
186 | help='list of integers, each int is the number of '
187 | 'downsampling steps (by a factor of 2) before each '
188 | 'stochastic layer')
189 |
190 | parser.add_argument('--learn-top-prior',
191 | action='store_true',
192 | help="learn the top-layer prior")
193 |
194 | parser.add_argument('--residual-type',
195 | type=str,
196 | choices=legal_resblock,
197 | default='bacdbacd',
198 | metavar='TYPE',
199 | help="type of residual blocks: " +
200 | list_options(legal_resblock))
201 |
202 | parser.add_argument('--merge-layers',
203 | type=str,
204 | choices=legal_merge_layers,
205 | default='residual',
206 | metavar='TYPE',
207 | help="type of merge layers: " +
208 | list_options(legal_merge_layers))
209 |
210 | parser.add_argument('--beta-anneal',
211 | type=int,
212 | default=0,
213 | metavar='B',
214 | help='steps for annealing beta from 0 to 1')
215 |
216 | parser.add_argument('--data-dep-init',
217 | action='store_true',
218 | dest='simple_data_dependent_init',
219 | help='use simple data-dependent initialization to '
220 | 'normalize outputs of affine layers')
221 |
222 | parser.add_argument('--wd',
223 | type=float,
224 | default=0.0,
225 | dest='weight_decay',
226 | help='weight decay')
227 |
228 | parser.add_argument('--nonlin',
229 | type=str,
230 | choices=legal_nonlin,
231 | default='elu',
232 | metavar='F',
233 | help="nonlinear activation: " +
234 | list_options(legal_nonlin))
235 |
236 | parser.add_argument('--dropout',
237 | type=float,
238 | default=0.2,
239 | metavar='D',
240 | help='dropout probability (in deterministic '
241 | 'layers)')
242 |
243 | parser.add_argument('--freebits',
244 | type=float,
245 | default=0.0,
246 | metavar='N',
247 | dest='free_bits',
248 | help='free bits (nats)')
249 |
250 | parser.add_argument('--analytical-kl',
251 | action='store_true',
252 | dest='analytical_kl',
253 | help='use analytical KL')
254 |
255 | parser.add_argument('--no-initial-downscaling',
256 | action='store_true',
257 | dest='no_initial_downscaling',
258 | help='do not downscale as first inference step '
259 | '(and upscale as last generation step)')
260 |
261 | @classmethod
262 | def _check_args(cls, args: argparse.Namespace) -> argparse.Namespace:
263 |
264 | args = super(LVAEExperiment, cls)._check_args(args)
265 |
266 | if len(args.z_dims) != len(args.downsample):
267 | msg = ("length of list of latent dimensions ({}) does not match "
268 | "length of list of downsampling factors ({})").format(
269 | len(args.z_dims), len(args.downsample))
270 | raise RuntimeError(msg)
271 |
272 | assert args.weight_decay >= 0.0
273 | assert 0.0 <= args.dropout <= 1.0
274 | if args.dropout < 1e-5:
275 | args.dropout = None
276 | assert args.free_bits >= 0.0
277 | args.batch_norm = not args.no_batch_norm
278 |
279 | likelihood_map = {
280 | 'static_mnist': 'bernoulli',
281 | 'multi_dsprites_binary_rgb': 'bernoulli',
282 | 'multi_mnist_binary': 'bernoulli',
283 | 'cifar10': 'discr_log_mix',
284 | 'celeba': 'discr_log_mix',
285 | 'svhn': 'discr_log_mix',
286 | }
287 | if args.likelihood is None: # default
288 | args.likelihood = likelihood_map[args.dataset_name]
289 |
290 | return args
291 |
292 | @staticmethod
293 | def _make_run_description(args: argparse.Namespace) -> str:
294 | s = ''
295 | s += args.dataset_name
296 | s += ',{}ly'.format(len(args.z_dims))
297 | # s += ',z=' + str(args.z_dims).replace(" ", "")
298 | # s += ',dwn=' + str(args.downsample).replace(" ", "")
299 | s += ',{}bpl'.format(args.blocks_per_layer)
300 | s += ',{}ch'.format(args.n_filters)
301 | if args.skip_connections:
302 | s += ',skip'
303 | if args.gated:
304 | s += ',gate'
305 | s += ',block=' + args.residual_type
306 | if args.beta_anneal != 0:
307 | s += ',b{}'.format(args.beta_anneal)
308 | s += ',{}'.format(args.nonlin)
309 | if args.free_bits > 0:
310 | s += ',freeb={}'.format(args.free_bits)
311 | if args.dropout is not None:
312 | s += ',drop={}'.format(args.dropout)
313 | if args.learn_top_prior:
314 | s += ',learnp'
315 | if args.weight_decay > 0.0:
316 | s += ',wd={}'.format(args.weight_decay)
317 | s += ',seed{}'.format(args.seed)
318 | if len(args.additional_descr) > 0:
319 | s += ',' + args.additional_descr
320 | return s
321 |
322 | def forward_pass(self,
323 | x: torch.Tensor,
324 | y: Optional[torch.Tensor] = None) -> dict:
325 |
326 | # Forward pass
327 | x = x.to(self.device, non_blocking=True)
328 | model_out = self.model(x)
329 | recons_sep = -model_out['ll']
330 | kl_sep = model_out['kl_sep']
331 | kl = model_out['kl']
332 | kl_loss = model_out['kl_loss']
333 |
334 | # ELBO
335 | elbo_sep = -(recons_sep + kl_sep)
336 | elbo = elbo_sep.mean()
337 |
338 | # Loss with beta
339 | beta = 1.
340 | if self.args.beta_anneal != 0:
341 | beta = linear_anneal(self.model.global_step, 0.0, 1.0,
342 | self.args.beta_anneal)
343 | recons = recons_sep.mean()
344 | loss = recons + kl_loss * beta
345 |
346 | # L2
347 | l2 = 0.0
348 | for p in self.model.parameters():
349 | l2 = l2 + torch.sum(p**2)
350 | l2 = l2.sqrt()
351 |
352 | output = {
353 | 'loss': loss,
354 | 'elbo': elbo,
355 | 'elbo_sep': elbo_sep,
356 | 'kl': kl,
357 | 'l2': l2,
358 | 'recons': recons,
359 | 'out_mean': model_out['out_mean'],
360 | 'out_mode': model_out['out_mode'],
361 | 'out_sample': model_out['out_sample'],
362 | 'likelihood_params': model_out['likelihood_params'],
363 | }
364 | if 'kl_avg_layerwise' in model_out:
365 | output['kl_avg_layerwise'] = model_out['kl_avg_layerwise']
366 |
367 | return output
368 |
369 | @classmethod
370 | def train_log_str(cls,
371 | summaries: dict,
372 | step: int,
373 | epoch: Optional[int] = None) -> str:
374 | s = " [step {}] loss: {:.5g} ELBO: {:.5g} recons: {:.3g} KL: {:.3g}"
375 | s = s.format(step, summaries['loss/loss'], summaries['elbo/elbo'],
376 | summaries['elbo/recons'], summaries['elbo/kl'])
377 | return s
378 |
379 | @classmethod
380 | def test_log_str(cls,
381 | summaries: dict,
382 | step: int,
383 | epoch: Optional[int] = None) -> str:
384 | s = " "
385 | if epoch is not None:
386 | s += "[step {}, epoch {}] ".format(step, epoch)
387 | s += "ELBO {:.5g} recons: {:.3g} KL: {:.3g}".format(
388 | summaries['elbo/elbo'], summaries['elbo/recons'],
389 | summaries['elbo/kl'])
390 | ll_key = None
391 | for k in summaries.keys():
392 | if k.find('elbo_IW') > -1:
393 | ll_key = k
394 | iw_samples = k.split('_')[-1]
395 | break
396 | if ll_key is not None:
397 | s += " marginal log-likelihood ({}) {:.5g}".format(
398 | iw_samples, summaries[ll_key])
399 |
400 | return s
401 |
402 | @classmethod
403 | def get_metrics_dict(cls, results: dict) -> dict:
404 | metrics_dict = {
405 | 'loss/loss': results['loss'].item(),
406 | 'elbo/elbo': results['elbo'].item(),
407 | 'elbo/recons': results['recons'].item(),
408 | 'elbo/kl': results['kl'].item(),
409 | 'l2/l2': results['l2'].item(),
410 | }
411 | if 'kl_avg_layerwise' in results:
412 | for i in range(len(results['kl_avg_layerwise'])):
413 | key = 'kl_layers/kl_layer_{}'.format(i)
414 | metrics_dict[key] = results['kl_avg_layerwise'][i].item()
415 | return metrics_dict
416 |
--------------------------------------------------------------------------------
/lib/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from urllib import request
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import TensorDataset
7 |
8 |
9 | class StaticBinaryMnist(TensorDataset):
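"""Statically binarized MNIST wrapped in a TensorDataset.

Labels are NaN placeholders, since the binarized dataset comes without labels.
"""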
10 |
11 | def __init__(self, folder, train, download=False, shuffle_init=False):
12 | self.download = download
13 | if train:
14 | sets = [
15 | self._get_binarized_mnist(folder, shuffle_init, split='train'),
16 | self._get_binarized_mnist(folder, shuffle_init, split='valid')
17 | ]
18 | x = np.concatenate(sets, axis=0)
19 | else:
20 | x = self._get_binarized_mnist(folder, shuffle_init, split='test')
21 | labels = torch.zeros(len(x),).fill_(float('nan'))
22 | super().__init__(torch.from_numpy(x), labels)
23 |
24 | def _get_binarized_mnist(self, folder, shuffle_init, split=None):
25 | """
26 | Get statically binarized MNIST. Code partially taken from
27 | https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py
28 | """
29 |
30 | subdatasets = ['train', 'valid', 'test']
31 | if split not in subdatasets:
32 | raise ValueError("Valid splits: {}".format(subdatasets))
33 | data = {}
34 |
35 | fname = 'binarized_mnist_{}.npz'.format(split)
36 | path = os.path.join(folder, fname)
37 |
38 | if not os.path.exists(path):
39 | print("Dataset file '{}' not found".format(path))
40 | if not self.download:
41 | msg = "Dataset not found, use download=True to download it"
42 | raise RuntimeError(msg)
43 |
44 | print("Downloading whole dataset...")
45 |
46 | os.makedirs(folder, exist_ok=True)
47 |
48 | for subdataset in subdatasets:
49 | fname_mat = 'binarized_mnist_{}.amat'.format(subdataset)
50 | url = ('http://www.cs.toronto.edu/~larocheh/public/datasets/'
51 | 'binarized_mnist/{}'.format(fname_mat))
52 | path_mat = os.path.join(folder, fname_mat)
53 | request.urlretrieve(url, path_mat)
54 |
55 | with open(path_mat) as f:
56 | lines = f.readlines()
57 |
58 | os.remove(path_mat)
59 | lines = np.array(
60 | [[int(i) for i in line.split()] for line in lines])
61 | data[subdataset] = lines.astype('float32').reshape(
62 | (-1, 1, 28, 28))
63 | np.savez_compressed(path_mat.split(".amat")[0],
64 | data=data[subdataset])
65 |
66 | else:
67 | data[split] = np.load(path)['data']
68 |
69 | if shuffle_init:
70 | np.random.shuffle(data[split])
71 |
72 | return data[split]
73 |
74 |
75 | def _pad_tensor(x, size, value=None):
76 | assert isinstance(x, torch.Tensor)
77 | input_size = len(x)
78 | if value is None:
79 | value = float('nan')
80 |
81 | # Copy input tensor into a tensor filled with specified value
82 | # Convert everything to float, not ideal but it's robust
83 | out = torch.zeros(*size, dtype=torch.float)
84 | out.fill_(value)
85 | if input_size > 0: # only if at least one element in the sequence
86 | out[:input_size] = x.float()
87 | return out
88 |
--------------------------------------------------------------------------------
/lib/likelihoods.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | from torch.distributions import Normal
8 | from torch.nn import functional as F
9 |
10 | from .stochastic import logistic_rsample, sample_from_discretized_mix_logistic
11 |
12 |
13 | class LikelihoodModule(nn.Module):
14 |
15 | def distr_params(self, x):
16 | pass
17 |
18 | @staticmethod
19 | def mean(params):
20 | pass
21 |
22 | @staticmethod
23 | def mode(params):
24 | pass
25 |
26 | @staticmethod
27 | def sample(params):
28 | pass
29 |
30 | def log_likelihood(self, x, params):
31 | pass
32 |
33 | def forward(self, input_, x):
34 | distr_params = self.distr_params(input_)
35 | mean = self.mean(distr_params)
36 | mode = self.mode(distr_params)
37 | sample = self.sample(distr_params)
38 | if x is None:
39 | ll = None
40 | else:
41 | ll = self.log_likelihood(x, distr_params)
42 | dct = {
43 | 'mean': mean,
44 | 'mode': mode,
45 | 'sample': sample,
46 | 'params': distr_params,
47 | }
48 | return ll, dct
49 |
50 |
51 | class BernoulliLikelihood(LikelihoodModule):
52 |
53 | def __init__(self, ch_in, color_channels):
54 | super().__init__()
55 | self.parameter_net = nn.Conv2d(ch_in,
56 | color_channels,
57 | kernel_size=3,
58 | padding=1)
59 |
60 | def distr_params(self, x):
61 | x = self.parameter_net(x)
62 | x = torch.sigmoid(x)
63 | return x
64 |
65 | @staticmethod
66 | def mean(params):
67 | return params
68 |
69 | @staticmethod
70 | def mode(params):
71 | return torch.round(params)
72 |
73 | @staticmethod
74 | def sample(params):
75 | return (torch.rand_like(params) < params).float()
76 |
77 | def log_likelihood(self, x, params):
78 | return log_bernoulli(x, params, reduce='none')
79 |
80 |
81 | class GaussianLikelihood(LikelihoodModule):
82 |
83 | def __init__(self, ch_in, color_channels):
84 | super().__init__()
85 | self.parameter_net = nn.Conv2d(ch_in,
86 | 2 * color_channels,
87 | kernel_size=3,
88 | padding=1)
89 |
90 | def distr_params(self, x):
91 | x = self.parameter_net(x)
92 | mean, lv = x.chunk(2, dim=1)
93 | params = {
94 | 'mean': mean,
95 | 'logvar': lv,
96 | }
97 | return params
98 |
99 | @staticmethod
100 | def mean(params):
101 | return params['mean']
102 |
103 | @staticmethod
104 | def mode(params):
105 | return params['mean']
106 |
107 | @staticmethod
108 | def sample(params):
109 | p = Normal(params['mean'], (params['logvar'] / 2).exp())
110 | return p.rsample()
111 |
112 | def log_likelihood(self, x, params):
113 | logprob = log_normal(x, params['mean'], params['logvar'], reduce='none')
114 | return logprob
115 |
116 |
117 | class DiscretizedLogisticLikelihood(LikelihoodModule):
118 | """
119 | Assume input data to be originally uint8 (0, ..., 255) and then rescaled
120 | by 1/255: discrete values in {0, 1/255, ..., 255/255}.
121 | When using the discretized logistic logprob implementation here, the data
122 | is rescaled by 255/256 and shifted by 1/512 in this class, so that it lies
123 | inside 256 equal bins between 0 and 1.
124 |
125 | Note that mean and logscale are parameters of the underlying continuous
126 | logistic distribution, not of its discretization.
127 | """
128 |
129 | log_scale_bias = -1.
130 |
131 | def __init__(self, ch_in, color_channels, n_bins, double=False):
132 | super().__init__()
133 | self.n_bins = n_bins
134 | self.double_precision = double
135 | self.parameter_net = nn.Conv2d(ch_in,
136 | 2 * color_channels,
137 | kernel_size=3,
138 | padding=1)
139 |
140 | def distr_params(self, x):
141 | x = self.parameter_net(x)
142 | mean, ls = x.chunk(2, dim=1)
143 | ls = ls + self.log_scale_bias
144 | ls = ls.clamp(min=-7.)
145 | mean = mean + 0.5 # initialize to mid interval
146 | params = {
147 | 'mean': mean,
148 | 'logscale': ls,
149 | }
150 | return params
151 |
152 | @staticmethod
153 | def mean(params):
154 | return params['mean']
155 |
156 | @staticmethod
157 | def mode(params):
158 | return params['mean']
159 |
160 | @staticmethod
161 | def sample(params):
162 | # We're not quantizing 8bit, but it doesn't matter
163 | sample = logistic_rsample((params['mean'], params['logscale']))
164 | sample = sample.clamp(min=0., max=1.)
165 | return sample
166 |
167 | def log_likelihood(self, x, params):
168 | # Input data x should be inside (not at the edge) n_bins equally-sized
169 | # bins between 0 and 1. E.g. if n_bins=256 the 257 bin edges are:
170 | # 0, 1/256, ..., 255/256, 1.
171 |
172 | x = x * (255 / 256) + 1 / 512
173 |
174 | logprob = log_discretized_logistic(x,
175 | params['mean'],
176 | params['logscale'],
177 | n_bins=self.n_bins,
178 | reduce='none',
179 | double=self.double_precision)
180 | return logprob
181 |
182 |
183 | class DiscretizedLogisticMixLikelihood(LikelihoodModule):
184 | """
185 | Sampling and loss computation are based on the original tf code.
186 |
187 | Assume input data to be originally uint8 (0, ..., 255) and then rescaled
188 | by 1/255: discrete values in {0, 1/255, ..., 255/255}.
189 | When using the original discretized logistic mixture logprob implementation,
190 | this data should be rescaled to be in [-1, 1].
191 |
192 | Mean and mode are not implemented for now.
193 |
194 | The number of color channels is fixed to 3 and n_bins to 256 for now.
195 | """
196 |
197 | def __init__(self, ch_in, n_components=10):
198 | super().__init__()
199 | self.parameter_net = nn.Conv2d(ch_in,
200 | 10 * n_components,
201 | kernel_size=3,
202 | padding=1)
203 |
204 | def distr_params(self, x):
205 | x = self.parameter_net(x)
206 | params = {
207 | 'mean': None, # TODO
208 | 'all_params': x
209 | }
210 | return params
211 |
212 | @staticmethod
213 | def mean(params):
214 | return params['mean']
215 |
216 | @staticmethod
217 | def mode(params):
218 | return params['mean']
219 |
220 | @staticmethod
221 | def sample(params):
222 | sample = sample_from_discretized_mix_logistic(params['all_params'])
223 | sample = (sample + 1) / 2
224 | sample = sample.clamp(min=0., max=1.)
225 | return sample
226 |
227 | def log_likelihood(self, x, params):
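# Rescale data from [0, 1] to [-1, 1], as expected by the PixelCNN++
# discretized mixture-of-logistics loss below.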
228 | x = x * 2 - 1
229 | logprob = -discretized_mix_logistic_loss(x, params['all_params'])
230 | return logprob
231 |
232 |
233 | def log_discretized_logistic(x,
234 | mean,
235 | log_scale,
236 | n_bins=256,
237 | reduce='mean',
238 | double=False):
239 | """
240 | Log of the probability mass of the values x under the logistic distribution
241 | with parameters mean and scale. The sum is taken over all dimensions except
242 | for the first one (assumed to be batch). Reduction is applied at the end.
243 |
244 | Assume input data to be inside (not at the edge) of n_bins equally-sized
245 | bins between 0 and 1. E.g. if n_bins=256 the 257 bin edges are:
246 | 0, 1/256, ..., 255/256, 1.
247 | If values are at the left edge it's also fine, but let's be on the safe side.
248 |
249 | Variance of logistic distribution is
250 | var = scale^2 * pi^2 / 3
251 |
252 | :param x: tensor with shape (batch, channels, dim1, dim2)
253 | :param mean: tensor with mean of distribution, shape
254 | (batch, channels, dim1, dim2)
255 | :param log_scale: tensor with log scale of distribution, shape has to be either
256 | scalar or broadcastable
257 | :param n_bins: number of bins (default: 256)
258 | :param reduce: reduction over batch: 'mean' | 'sum' | 'none'
259 | :param double: whether double precision should be used for computations
260 | :return:
261 | """
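# The mass of the bin [x, x + 1/n_bins) under Logistic(mean, scale) is
# sigmoid((x + 1/n_bins - mean)/scale) - sigmoid((x - mean)/scale); below,
# the CDF is pinned to 1 past the last bin edge and to 0 before the first.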
262 | log_scale = _input_check(x, mean, log_scale, reduce)
263 | if double:
264 | log_scale = log_scale.double()
265 | x = x.double()
266 | mean = mean.double()
267 | eps = 1e-14
268 | else:
269 | eps = 1e-7
270 | # scale = np.sqrt(3) * torch.exp(logvar / 2) / np.pi
271 | scale = log_scale.exp()
272 |
273 | # Set values to the left of each bin
274 | x = torch.floor(x * n_bins) / n_bins
275 |
276 | cdf_plus = torch.ones_like(x)
277 | idx = x < (n_bins - 1) / n_bins
278 | cdf_plus[idx] = torch.sigmoid(
279 | (x[idx] + 1 / n_bins - mean[idx]) / scale[idx])
280 | cdf_minus = torch.zeros_like(x)
281 | idx = x >= 1 / n_bins
282 | cdf_minus[idx] = torch.sigmoid((x[idx] - mean[idx]) / scale[idx])
283 | log_prob = torch.log(cdf_plus - cdf_minus + eps)
284 | log_prob = log_prob.sum((1, 2, 3))
285 | log_prob = _reduce(log_prob, reduce)
286 | if double:
287 | log_prob = log_prob.float()
288 | return log_prob
289 |
290 |
291 | def discretized_mix_logistic_loss(x, l):
292 | """
293 | log-likelihood for mixture of discretized logistics, assumes the data
294 | has been rescaled to [-1,1] interval
295 |
296 | Code taken from pytorch adaptation of original PixelCNN++ tf implementation
297 | https://github.com/pclucas14/pixel-cnn-pp
298 | """
299 |
300 | # channels last
301 | x = x.permute(0, 2, 3, 1)
302 | l = l.permute(0, 2, 3, 1)
303 |
304 | # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
305 | xs = [int(y) for y in x.size()]
306 | # predicted distribution, e.g. (B,32,32,100)
307 | ls = [int(y) for y in l.size()]
308 |
309 | # here and below: unpacking the params of the mixture of logistics
310 | nr_mix = int(ls[-1] / 10)
311 | logit_probs = l[:, :, :, :nr_mix]
312 | l = l[:, :, :, nr_mix:].contiguous().view(
313 | xs + [nr_mix * 3]) # 3 for mean, scale, coef
314 | means = l[:, :, :, :, :nr_mix]
315 | # log_scales = torch.max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.)
316 | log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
317 |
318 | coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])
319 | # here and below: getting the means and adjusting them based on preceding
320 | # sub-pixels
321 | x = x.contiguous()
322 | x = x.unsqueeze(-1) + nn.Parameter(torch.zeros(xs + [nr_mix]).to(x.device),
323 | requires_grad=False)
324 | m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :]).view(
325 | xs[0], xs[1], xs[2], 1, nr_mix)
326 |
327 | m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
328 | coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(
329 | xs[0], xs[1], xs[2], 1, nr_mix)
330 |
331 | means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
332 | centered_x = x - means
333 | inv_stdv = torch.exp(-log_scales)
334 | plus_in = inv_stdv * (centered_x + 1. / 255.)
335 | cdf_plus = torch.sigmoid(plus_in)
336 | min_in = inv_stdv * (centered_x - 1. / 255.)
337 | cdf_min = torch.sigmoid(min_in)
338 | # log probability for edge case of 0 (before scaling)
339 | log_cdf_plus = plus_in - F.softplus(plus_in)
340 | # log probability for edge case of 255 (before scaling)
341 | log_one_minus_cdf_min = -F.softplus(min_in)
342 | cdf_delta = cdf_plus - cdf_min # probability for all other cases
343 | mid_in = inv_stdv * centered_x
344 | # log probability in the center of the bin, to be used in extreme cases
345 | # (not actually used in our code)
346 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
347 |
348 | # now select the right output: left edge case, right edge case, normal
349 | # case, extremely low prob case (doesn't actually happen for us)
350 |
351 |     # this is what we are really doing, but we use the robust version below
352 |     # for extreme cases in other applications and to avoid the NaN issue with tf.select()
353 | # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999,
354 | # log_one_minus_cdf_min, tf.log(cdf_delta)))
355 |
356 | # robust version, that still works if probabilities are below 1e-5 (which
357 | # never happens in our code)
358 | # tensorflow backpropagates through tf.select() by multiplying with zero
359 |     # instead of selecting: this requires us to use some ugly tricks to avoid
360 | # potential NaNs
361 | # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as
362 | # output, it's purely there to get around the tf.select() gradient issue
363 | # if the probability on a sub-pixel is below 1e-5, we use an approximation
364 | # based on the assumption that the log-density is constant in the bin of
365 | # the observed sub-pixel value
366 |
367 | inner_inner_cond = (cdf_delta > 1e-5).float()
368 | inner_inner_out = inner_inner_cond * torch.log(
369 | torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (
370 | log_pdf_mid - np.log(127.5))
371 | inner_cond = (x > 0.999).float()
372 | inner_out = inner_cond * log_one_minus_cdf_min + (
373 | 1. - inner_cond) * inner_inner_out
374 | cond = (x < -0.999).float()
375 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
376 | log_probs = torch.sum(log_probs, dim=3) + torch.log_softmax(logit_probs,
377 | dim=-1)
378 | log_probs = torch.logsumexp(log_probs, dim=-1)
379 |
380 |     # (summing over the batch as well would give the scalar -torch.sum(log_probs))
381 | loss_sep = -log_probs.sum((1, 2)) # keep batch dimension
382 | return loss_sep
383 |
384 |
385 | def log_bernoulli(x, mean, reduce='mean'):
386 | log_prob = -F.binary_cross_entropy(mean, x, reduction='none')
387 | log_prob = log_prob.sum((1, 2, 3))
388 | return _reduce(log_prob, reduce)
389 |
390 |
391 | def log_normal(x, mean, logvar, reduce='mean'):
392 | """
393 |     Log of the probability density of the values x under the Normal
394 | distribution with parameters mean and logvar. The sum is taken over all
395 | dimensions except for the first one (assumed to be batch). Reduction
396 | is applied at the end.
397 | :param x: tensor of points, with shape (batch, channels, dim1, dim2)
398 | :param mean: tensor with mean of distribution, shape
399 | (batch, channels, dim1, dim2)
400 | :param logvar: tensor with log-variance of distribution, shape has to be
401 | either scalar or broadcastable
402 | :param reduce: reduction over batch: 'mean' | 'sum' | 'none'
403 | :return:
404 | """
405 |
406 | logvar = _input_check(x, mean, logvar, reduce)
407 | var = torch.exp(logvar)
408 | log_prob = -0.5 * ((
409 | (x - mean)**2) / var + logvar + torch.tensor(2 * math.pi).log())
410 | log_prob = log_prob.sum((1, 2, 3))
411 | return _reduce(log_prob, reduce)
412 |
413 |
414 | def _reduce(x, reduce):
415 | if reduce == 'mean':
416 | x = x.mean()
417 | elif reduce == 'sum':
418 | x = x.sum()
419 | return x
420 |
421 |
422 | def _input_check(x, mean, scale_param, reduce):
423 | assert x.dim() == 4
424 | assert x.size() == mean.size()
425 | if scale_param.numel() == 1:
426 | scale_param = scale_param.view(1, 1, 1, 1)
427 | if reduce not in ['mean', 'sum', 'none']:
428 | msg = "unrecognized reduction method '{}'".format(reduce)
429 | raise RuntimeError(msg)
430 | return scale_param
431 |
432 |
433 | if __name__ == '__main__':
434 |
435 | import seaborn as sns
436 | sns.set()
437 |
438 | # *** Test discretized logistic likelihood and plot examples
439 |
440 | # Fix predicted distribution, change true data from 0 to 1:
441 | # show log probability of given distribution on the range [0, 1]
442 | t = torch.arange(0., 1., 1 / 10000).view(-1, 1, 1, 1)
443 | mean_ = torch.zeros_like(t) + 0.3
444 | logscales = np.arange(-7., 0., 1.)
445 | plt.figure(figsize=(15, 8))
446 | for logscale_ in logscales:
447 | logscale = torch.tensor(logscale_).expand_as(t)
448 | log_prob = log_discretized_logistic(t,
449 | mean_,
450 | logscale,
451 | n_bins=256,
452 | reduce='none',
453 | double=True)
454 | plt.plot(t.flatten().numpy(),
455 | log_prob.numpy(),
456 | label='logscale={}'.format(logscale_))
457 | plt.xlabel('data (x)')
458 | plt.ylabel('logprob')
459 | plt.title('log DiscrLogistic(x | 0.3, scale)')
460 | plt.legend()
461 | plt.show()
462 |
463 | # Fix true data, change distribution:
464 | # show log probability of fixed data under different distributions
465 | logscales = np.arange(-7., 0., 1.)
466 | mean_ = torch.arange(0., 1., 1 / 10000).view(-1, 1, 1, 1)
467 | t = torch.tensor(0.3).expand_as(mean_)
468 | plt.figure(figsize=(15, 8))
469 | for logscale_ in logscales:
470 | logscale = torch.tensor(logscale_).expand_as(mean_)
471 | log_prob = log_discretized_logistic(t,
472 | mean_,
473 | logscale,
474 | n_bins=256,
475 | reduce='none',
476 | double=True)
477 | plt.plot(mean_.flatten().numpy(),
478 | log_prob.numpy(),
479 | label='logscale={}'.format(logscale_))
480 | plt.xlabel('mean of logistic')
481 | plt.ylabel('logprob')
482 | plt.title('log DiscrLogistic(0.3 | mean, scale)')
483 | plt.legend()
484 | plt.show()
485 |
--------------------------------------------------------------------------------
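
As an aside (not part of the repository): the parameter layout that
discretized_mix_logistic_loss above expects, for nr_mix mixture components and
3 color channels, is nr_mix mixture logits followed by nr_mix means, nr_mix
log-scales, and nr_mix autoregressive coefficients per sub-pixel, i.e.
10 * nr_mix channels in total. A minimal sketch, assuming the repository root
is on PYTHONPATH:

    import torch
    from lib.likelihoods import discretized_mix_logistic_loss

    nr_mix = 10
    x = torch.rand(4, 3, 32, 32) * 2 - 1      # data rescaled to [-1, 1]
    l = torch.randn(4, 10 * nr_mix, 32, 32)   # mixture logits + per-sub-pixel params
    nll = discretized_mix_logistic_loss(x, l)
    print(nll.shape)   # torch.Size([4]): negative log-likelihood per image
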
/lib/nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class ResidualBlock(nn.Module):
6 | """
7 | Residual block with 2 convolutional layers.
8 | Input, intermediate, and output channels are the same. Padding is always
9 | 'same'. The 2 convolutional layers have the same groups. No stride allowed,
10 | and kernel sizes have to be odd.
11 |
12 | The result is:
13 | out = gate(f(x)) + x
14 | where an argument controls the presence of the gating mechanism, and f(x)
15 | has different structures depending on the argument block_type.
16 | block_type is a string specifying the structure of the block, where:
17 | a = activation
18 | b = batch norm
19 | c = conv layer
20 | d = dropout.
21 | For example, bacdbacd has 2x (batchnorm, activation, conv, dropout).
22 | """
23 |
24 | default_kernel_size = (3, 3)
25 |
26 | def __init__(self,
27 | channels,
28 | nonlin,
29 | kernel=None,
30 | groups=1,
31 | batchnorm=True,
32 | block_type=None,
33 | dropout=None,
34 | gated=None):
35 | super().__init__()
36 | if kernel is None:
37 | kernel = self.default_kernel_size
38 | elif isinstance(kernel, int):
39 | kernel = (kernel, kernel)
40 | elif len(kernel) != 2:
41 | raise ValueError(
42 | "kernel has to be None, int, or an iterable of length 2")
43 | assert all([k % 2 == 1 for k in kernel]), "kernel sizes have to be odd"
44 | kernel = list(kernel)
45 | pad = [k // 2 for k in kernel]
46 | self.gated = gated
47 |
48 | modules = []
49 |
50 | if block_type == 'cabdcabd':
51 | for i in range(2):
52 | conv = nn.Conv2d(channels,
53 | channels,
54 | kernel[i],
55 | padding=pad[i],
56 | groups=groups)
57 | modules.append(conv)
58 | modules.append(nonlin())
59 | if batchnorm:
60 | modules.append(nn.BatchNorm2d(channels))
61 | if dropout is not None:
62 | modules.append(nn.Dropout2d(dropout))
63 |
64 | elif block_type == 'bacdbac':
65 | for i in range(2):
66 | if batchnorm:
67 | modules.append(nn.BatchNorm2d(channels))
68 | modules.append(nonlin())
69 | conv = nn.Conv2d(channels,
70 | channels,
71 | kernel[i],
72 | padding=pad[i],
73 | groups=groups)
74 | modules.append(conv)
75 | if dropout is not None and i == 0:
76 | modules.append(nn.Dropout2d(dropout))
77 |
78 | elif block_type == 'bacdbacd':
79 | for i in range(2):
80 | if batchnorm:
81 | modules.append(nn.BatchNorm2d(channels))
82 | modules.append(nonlin())
83 | conv = nn.Conv2d(channels,
84 | channels,
85 | kernel[i],
86 | padding=pad[i],
87 | groups=groups)
88 | modules.append(conv)
89 |                 if dropout is not None: modules.append(nn.Dropout2d(dropout))
90 |
91 | else:
92 | raise ValueError("unrecognized block type '{}'".format(block_type))
93 |
94 | if gated:
95 | modules.append(GateLayer2d(channels, 1, nonlin))
96 | self.block = nn.Sequential(*modules)
97 |
98 | def forward(self, x):
99 | return self.block(x) + x
100 |
101 |
102 | class ResidualGatedBlock(ResidualBlock):
103 |
104 | def __init__(self, *args, **kwargs):
105 | super().__init__(*args, **kwargs, gated=True)
106 |
107 |
108 | class GateLayer2d(nn.Module):
109 | """
110 | Double the number of channels through a convolutional layer, then use
111 | half the channels as gate for the other half.
112 | """
113 |
114 | def __init__(self, channels, kernel_size, nonlin=nn.LeakyReLU):
115 | super().__init__()
116 | assert kernel_size % 2 == 1
117 | pad = kernel_size // 2
118 | self.conv = nn.Conv2d(channels, 2 * channels, kernel_size, padding=pad)
119 | self.nonlin = nonlin()
120 |
121 | def forward(self, x):
122 | x = self.conv(x)
123 | x, gate = torch.chunk(x, 2, dim=1)
124 | x = self.nonlin(x) # TODO remove this?
125 | gate = torch.sigmoid(gate)
126 | return x * gate
127 |
--------------------------------------------------------------------------------
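
A minimal usage sketch of ResidualBlock (the hyperparameters here are
illustrative, not taken from the repository's configs): a gated 'bacdbac'
block preserves the input shape, since padding is always 'same' and the
number of channels is unchanged.

    import torch
    from torch import nn
    from lib.nn import ResidualBlock

    block = ResidualBlock(channels=32,
                          nonlin=nn.ELU,
                          kernel=3,
                          block_type='bacdbac',
                          dropout=0.2,
                          gated=True)
    x = torch.randn(8, 32, 16, 16)
    out = block(x)   # gate(f(x)) + x, same shape as the input
    assert out.shape == x.shape
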
/lib/stochastic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.distributions import kl_divergence
4 | from torch.distributions.normal import Normal
5 |
6 |
7 | class NormalStochasticBlock2d(nn.Module):
8 | """
9 | Transform input parameters to q(z) with a convolution, optionally do the
10 | same for p(z), then sample z ~ q(z) and return conv(z).
11 |
12 | If q's parameters are not given, do the same but sample from p(z).
13 | """
14 |
15 | def __init__(self, c_in, c_vars, c_out, kernel=3, transform_p_params=True):
16 | super().__init__()
17 | assert kernel % 2 == 1
18 | pad = kernel // 2
19 | self.transform_p_params = transform_p_params
20 | self.c_in = c_in
21 | self.c_out = c_out
22 | self.c_vars = c_vars
23 |
24 | if transform_p_params:
25 | self.conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
26 | self.conv_in_q = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
27 | self.conv_out = nn.Conv2d(c_vars, c_out, kernel, padding=pad)
28 |
29 | def forward(self,
30 | p_params,
31 | q_params=None,
32 | forced_latent=None,
33 | use_mode=False,
34 | force_constant_output=False,
35 | analytical_kl=False):
36 |
37 | assert (forced_latent is None) or (not use_mode)
38 |
39 | if self.transform_p_params:
40 | p_params = self.conv_in_p(p_params)
41 | else:
42 | assert p_params.size(1) == 2 * self.c_vars
43 |
44 | # Define p(z)
45 | p_mu, p_lv = p_params.chunk(2, dim=1)
46 | p = Normal(p_mu, (p_lv / 2).exp())
47 |
48 | if q_params is not None:
49 | # Define q(z)
50 | q_params = self.conv_in_q(q_params)
51 | q_mu, q_lv = q_params.chunk(2, dim=1)
52 | q = Normal(q_mu, (q_lv / 2).exp())
53 |
54 | # Sample from q(z)
55 | sampling_distrib = q
56 | else:
57 | # Sample from p(z)
58 | sampling_distrib = p
59 |
60 | # Generate latent variable (typically by sampling)
61 | if forced_latent is None:
62 | if use_mode:
63 | z = sampling_distrib.mean
64 | else:
65 | z = sampling_distrib.rsample()
66 | else:
67 | z = forced_latent
68 |
69 | # Copy one sample (and distrib parameters) over the whole batch.
70 |         # This is used in experiments sampling from the prior - q is not used.
71 | if force_constant_output:
72 | z = z[0:1].expand_as(z).clone()
73 | p_params = p_params[0:1].expand_as(p_params).clone()
74 |
75 | # Output of stochastic layer
76 | out = self.conv_out(z)
77 |
78 | # Compute log p(z)
79 | logprob_p = p.log_prob(z).sum((1, 2, 3))
80 |
81 | if q_params is not None:
82 |
83 | # Compute log q(z)
84 | logprob_q = q.log_prob(z).sum((1, 2, 3))
85 |
86 | # Compute KL (analytical or MC estimate)
87 | kl_analytical = kl_divergence(q, p)
88 | if analytical_kl:
89 | kl_elementwise = kl_analytical
90 | else:
91 | kl_elementwise = kl_normal_mc(z, p_params, q_params)
92 | kl_samplewise = kl_elementwise.sum((1, 2, 3))
93 |
94 | # Compute spatial KL analytically (but conditioned on samples from
95 | # previous layers)
96 | kl_spatial_analytical = kl_analytical.sum(1)
97 |
98 | else:
99 | kl_elementwise = kl_samplewise = kl_spatial_analytical = None
100 | logprob_q = None
101 |
102 | data = {
103 | 'z': z, # sampled variable at this layer (batch, ch, h, w)
104 | 'p_params': p_params, # (b, ch, h, w) where b is 1 or batch size
105 | 'q_params': q_params, # (batch, ch, h, w)
106 | 'logprob_p': logprob_p, # (batch, )
107 | 'logprob_q': logprob_q, # (batch, )
108 | 'kl_elementwise': kl_elementwise, # (batch, ch, h, w)
109 | 'kl_samplewise': kl_samplewise, # (batch, )
110 | 'kl_spatial': kl_spatial_analytical, # (batch, h, w)
111 | }
112 | return out, data
113 |
114 |
115 | def logistic_rsample(mu_ls):
116 | """
117 | Returns a sample from Logistic with specified mean and log scale.
118 | :param mu_ls: a tensor containing mean and log scale along dim=1,
119 | or a tuple (mean, log scale)
120 | :return: a reparameterized sample with the same size as the input
121 | mean and log scale
122 | """
123 |
124 | # Get parameters
125 | try:
126 | mu, log_scale = torch.chunk(mu_ls, 2, dim=1)
127 | except TypeError:
128 | mu, log_scale = mu_ls
129 | scale = log_scale.exp()
130 |
131 | # Get uniform sample in open interval (0, 1)
132 | u = torch.zeros_like(mu)
133 | u.uniform_(1e-7, 1 - 1e-7)
134 |
135 | # Transform into logistic sample
136 | sample = mu + scale * (torch.log(u) - torch.log(1 - u))
137 |
138 | return sample
139 |
140 |
141 | def sample_from_discretized_mix_logistic(l):
142 | """
143 | Code taken from pytorch adaptation of original PixelCNN++ tf implementation
144 | https://github.com/pclucas14/pixel-cnn-pp
145 | """
146 |
147 | def to_one_hot(tensor, n):
148 | one_hot = torch.zeros(tensor.size() + (n,))
149 | one_hot = one_hot.to(tensor.device)
150 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), 1.)
151 | return one_hot
152 |
153 |     # channels last (convert from PyTorch to TF ordering)
154 | l = l.permute(0, 2, 3, 1)
155 | ls = [int(y) for y in l.size()]
156 | xs = ls[:-1] + [3]
157 |
158 | # here and below: unpacking the params of the mixture of logistics
159 | nr_mix = int(ls[-1] / 10)
160 |
161 | # unpack parameters
162 | logit_probs = l[:, :, :, :nr_mix]
163 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
164 | # sample mixture indicator from softmax
165 | temp = torch.FloatTensor(logit_probs.size())
166 | if l.is_cuda:
167 | temp = temp.cuda()
168 | temp.uniform_(1e-5, 1. - 1e-5)
169 | temp = logit_probs.data - torch.log(-torch.log(temp))
170 | _, argmax = temp.max(dim=3)
171 |
172 | one_hot = to_one_hot(argmax, nr_mix)
173 | sel = one_hot.view(xs[:-1] + [1, nr_mix])
174 | # select logistic parameters
175 | means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
176 | log_scales = torch.clamp(torch.sum(l[:, :, :, :, nr_mix:2 * nr_mix] * sel,
177 | dim=4),
178 | min=-7.)
179 | coeffs = torch.sum(torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel,
180 | dim=4)
181 | # sample from logistic & clip to interval
182 | # we don't actually round to the nearest 8bit value when sampling
183 | u = torch.FloatTensor(means.size())
184 | if l.is_cuda:
185 | u = u.cuda()
186 | u.uniform_(1e-5, 1. - 1e-5)
187 | u = nn.Parameter(u)
188 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
189 | x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.)
190 | x1 = torch.clamp(torch.clamp(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0,
191 | min=-1.),
192 | max=1.)
193 | x2 = torch.clamp(torch.clamp(x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 +
194 | coeffs[:, :, :, 2] * x1,
195 | min=-1.),
196 | max=1.)
197 |
198 | out = torch.cat([
199 | x0.view(xs[:-1] + [1]),
200 | x1.view(xs[:-1] + [1]),
201 | x2.view(xs[:-1] + [1])
202 | ],
203 | dim=3)
204 | # put back in Pytorch ordering
205 | out = out.permute(0, 3, 1, 2)
206 | return out
207 |
208 |
209 | def kl_normal_mc(z, p_mulv, q_mulv):
210 | """
211 | One-sample estimation of element-wise KL between two diagonal
212 | multivariate normal distributions. Any number of dimensions,
213 | broadcasting supported (be careful).
214 |
215 |     :param z: sample from q (or p), tensor of shape (batch, channels, ...)
216 |     :param p_mulv: mean and log-variance of p, concatenated along dim=1
217 |     :param q_mulv: mean and log-variance of q, concatenated along dim=1
218 |     :return: element-wise estimate of KL(q || p), i.e. log q(z) - log p(z)
219 | """
220 | p_mu, p_lv = torch.chunk(p_mulv, 2, dim=1)
221 | q_mu, q_lv = torch.chunk(q_mulv, 2, dim=1)
222 | p_std = (p_lv / 2).exp()
223 | q_std = (q_lv / 2).exp()
224 | p_distrib = Normal(p_mu, p_std)
225 | q_distrib = Normal(q_mu, q_std)
226 | return q_distrib.log_prob(z) - p_distrib.log_prob(z)
227 |
--------------------------------------------------------------------------------
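
A sketch of NormalStochasticBlock2d in the two regimes described in its
docstring (shapes are illustrative):

    import torch
    from lib.stochastic import NormalStochasticBlock2d

    block = NormalStochasticBlock2d(c_in=32, c_vars=4, c_out=32)
    p_feat = torch.randn(8, 32, 16, 16)   # top-down features defining p(z)
    q_feat = torch.randn(8, 32, 16, 16)   # merged features defining q(z|x)

    # Inference: sample z ~ q(z|x) with the reparameterization trick
    out, data = block(p_feat, q_params=q_feat)
    print(out.shape)                      # torch.Size([8, 32, 16, 16])
    print(data['kl_samplewise'].shape)    # torch.Size([8])

    # Generation: q is not available, sample z ~ p(z) instead
    out, data = block(p_feat)
    assert data['kl_samplewise'] is None
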
/main.py:
--------------------------------------------------------------------------------
1 | from boilr import Trainer
2 |
3 | from experiment import LVAEExperiment
4 |
5 |
6 | def main():
7 | experiment = LVAEExperiment()
8 | trainer = Trainer(experiment)
9 | trainer.run()
10 |
11 |
12 | if __name__ == "__main__":
13 | main()
14 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/models/__init__.py
--------------------------------------------------------------------------------
/models/lvae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from boilr.models import BaseGenerativeModel
4 | from boilr.nn import crop_img_tensor, pad_img_tensor, Interpolate, free_bits_kl
5 | from torch import nn
6 |
7 | from lib.likelihoods import (BernoulliLikelihood, GaussianLikelihood,
8 | DiscretizedLogisticLikelihood,
9 | DiscretizedLogisticMixLikelihood)
10 | from .lvae_layers import (TopDownLayer, BottomUpLayer,
11 | TopDownDeterministicResBlock,
12 | BottomUpDeterministicResBlock)
13 |
14 |
15 | class LadderVAE(BaseGenerativeModel):
16 |
17 | def __init__(self,
18 | color_ch,
19 | z_dims,
20 | blocks_per_layer=2,
21 | downsample=None,
22 | nonlin='elu',
23 | merge_type=None,
24 | batchnorm=True,
25 | stochastic_skip=False,
26 | n_filters=32,
27 | dropout=None,
28 | free_bits=0.0,
29 | learn_top_prior=False,
30 | img_shape=None,
31 | likelihood_form=None,
32 | res_block_type=None,
33 | gated=False,
34 | no_initial_downscaling=False,
35 | analytical_kl=False):
36 | super().__init__()
37 | self.color_ch = color_ch
38 | self.z_dims = z_dims
39 | self.blocks_per_layer = blocks_per_layer
40 | self.downsample = downsample
41 | self.n_layers = len(self.z_dims)
42 | self.stochastic_skip = stochastic_skip
43 | self.n_filters = n_filters
44 | self.dropout = dropout
45 | self.free_bits = free_bits
46 | self.learn_top_prior = learn_top_prior
47 | self.img_shape = tuple(img_shape)
48 | self.res_block_type = res_block_type
49 | self.gated = gated
50 |
51 | # Default: no downsampling (except for initial bottom-up block)
52 | if self.downsample is None:
53 | self.downsample = [0] * self.n_layers
54 |
55 | # Downsample by a factor of 2 at each downsampling operation
56 | self.overall_downscale_factor = np.power(2, sum(self.downsample))
57 | if not no_initial_downscaling: # by default do another downscaling
58 | self.overall_downscale_factor *= 2
59 |
60 | assert max(self.downsample) <= self.blocks_per_layer
61 | assert len(self.downsample) == self.n_layers
62 |
63 | # Get class of nonlinear activation from string description
64 | nonlin = {
65 | 'relu': nn.ReLU,
66 | 'leakyrelu': nn.LeakyReLU,
67 | 'elu': nn.ELU,
68 | 'selu': nn.SELU,
69 | }[nonlin]
70 |
71 | # First bottom-up layer: change num channels + downsample by factor 2
72 | # unless we want to prevent this
73 | stride = 1 if no_initial_downscaling else 2
74 | self.first_bottom_up = nn.Sequential(
75 | nn.Conv2d(color_ch, n_filters, 5, padding=2, stride=stride),
76 | nonlin(),
77 | BottomUpDeterministicResBlock(
78 | c_in=n_filters,
79 | c_out=n_filters,
80 | nonlin=nonlin,
81 | batchnorm=batchnorm,
82 | dropout=dropout,
83 | res_block_type=res_block_type,
84 | ))
85 |
86 | # Init lists of layers
87 | self.top_down_layers = nn.ModuleList([])
88 | self.bottom_up_layers = nn.ModuleList([])
89 |
90 | for i in range(self.n_layers):
91 |
92 | # Whether this is the top layer
93 | is_top = i == self.n_layers - 1
94 |
95 | # Add bottom-up deterministic layer at level i.
96 | # It's a sequence of residual blocks (BottomUpDeterministicResBlock)
97 | # possibly with downsampling between them.
98 | self.bottom_up_layers.append(
99 | BottomUpLayer(
100 | n_res_blocks=self.blocks_per_layer,
101 | n_filters=n_filters,
102 | downsampling_steps=self.downsample[i],
103 | nonlin=nonlin,
104 | batchnorm=batchnorm,
105 | dropout=dropout,
106 | res_block_type=res_block_type,
107 | gated=gated,
108 | ))
109 |
110 | # Add top-down stochastic layer at level i.
111 | # The architecture when doing inference is roughly as follows:
112 | # p_params = output of top-down layer above
113 | # bu = inferred bottom-up value at this layer
114 | # q_params = merge(bu, p_params)
115 | # z = stochastic_layer(q_params)
116 | # possibly get skip connection from previous top-down layer
117 | # top-down deterministic ResNet
118 | #
119 | # When doing generation only, the value bu is not available, the
120 | # merge layer is not used, and z is sampled directly from p_params.
121 | #
122 | self.top_down_layers.append(
123 | TopDownLayer(
124 | z_dim=z_dims[i],
125 | n_res_blocks=blocks_per_layer,
126 | n_filters=n_filters,
127 | is_top_layer=is_top,
128 | downsampling_steps=downsample[i],
129 | nonlin=nonlin,
130 | merge_type=merge_type,
131 | batchnorm=batchnorm,
132 | dropout=dropout,
133 | stochastic_skip=stochastic_skip,
134 | learn_top_prior=learn_top_prior,
135 | top_prior_param_shape=self.get_top_prior_param_shape(),
136 | res_block_type=res_block_type,
137 | gated=gated,
138 | analytical_kl=analytical_kl,
139 | ))
140 |
141 | # Final top-down layer
142 | modules = list()
143 | if not no_initial_downscaling:
144 | modules.append(Interpolate(scale=2))
145 | for i in range(blocks_per_layer):
146 | modules.append(
147 | TopDownDeterministicResBlock(
148 | c_in=n_filters,
149 | c_out=n_filters,
150 | nonlin=nonlin,
151 | batchnorm=batchnorm,
152 | dropout=dropout,
153 | res_block_type=res_block_type,
154 | gated=gated,
155 | ))
156 | self.final_top_down = nn.Sequential(*modules)
157 |
158 | # Define likelihood
159 | if likelihood_form == 'bernoulli':
160 | self.likelihood = BernoulliLikelihood(n_filters, color_ch)
161 | elif likelihood_form == 'gaussian':
162 | self.likelihood = GaussianLikelihood(n_filters, color_ch)
163 | elif likelihood_form == 'discr_log':
164 | self.likelihood = DiscretizedLogisticLikelihood(
165 | n_filters, color_ch, 256)
166 | elif likelihood_form == 'discr_log_mix':
167 | self.likelihood = DiscretizedLogisticMixLikelihood(n_filters)
168 | else:
169 | msg = "Unrecognized likelihood '{}'".format(likelihood_form)
170 | raise RuntimeError(msg)
171 |
172 | def forward(self, x):
173 | img_size = x.size()[2:]
174 |
175 | # Pad input to make everything easier with conv strides
176 | x_pad = self.pad_input(x)
177 |
178 | # Bottom-up inference: return list of length n_layers (bottom to top)
179 | bu_values = self.bottomup_pass(x_pad)
180 |
181 | # Top-down inference/generation
182 | out, td_data = self.topdown_pass(bu_values)
183 |
184 | # Restore original image size
185 | out = crop_img_tensor(out, img_size)
186 |
187 | # Log likelihood and other info (per data point)
188 | ll, likelihood_info = self.likelihood(out, x)
189 |
190 | # kl[i] for each i has length batch_size
191 | # resulting kl shape: (batch_size, layers)
192 | kl = torch.cat([kl_layer.unsqueeze(1) for kl_layer in td_data['kl']],
193 | dim=1)
194 |
195 | kl_sep = kl.sum(1)
196 | kl_avg_layerwise = kl.mean(0)
197 | kl_loss = free_bits_kl(kl, self.free_bits).sum() # sum over layers
198 | kl = kl_sep.mean()
199 |
200 | output = {
201 | 'll': ll,
202 | 'z': td_data['z'],
203 | 'kl': kl,
204 | 'kl_sep': kl_sep,
205 | 'kl_avg_layerwise': kl_avg_layerwise,
206 | 'kl_spatial': td_data['kl_spatial'],
207 | 'kl_loss': kl_loss,
208 | 'logp': td_data['logprob_p'],
209 | 'out_mean': likelihood_info['mean'],
210 | 'out_mode': likelihood_info['mode'],
211 | 'out_sample': likelihood_info['sample'],
212 | 'likelihood_params': likelihood_info['params']
213 | }
214 | return output
215 |
216 | def bottomup_pass(self, x):
217 | # Bottom-up initial layer
218 | x = self.first_bottom_up(x)
219 |
220 | # Loop from bottom to top layer, store all deterministic nodes we
221 | # need in the top-down pass
222 | bu_values = []
223 | for i in range(self.n_layers):
224 | x = self.bottom_up_layers[i](x)
225 | bu_values.append(x)
226 |
227 | return bu_values
228 |
229 | def topdown_pass(self,
230 | bu_values=None,
231 | n_img_prior=None,
232 | mode_layers=None,
233 | constant_layers=None,
234 | forced_latent=None):
235 |
236 | # Default: no layer is sampled from the distribution's mode
237 | if mode_layers is None:
238 | mode_layers = []
239 | if constant_layers is None:
240 | constant_layers = []
241 | prior_experiment = len(mode_layers) > 0 or len(constant_layers) > 0
242 |
243 | # If the bottom-up inference values are not given, don't do
244 | # inference, sample from prior instead
245 | inference_mode = bu_values is not None
246 |
247 | # Check consistency of arguments
248 | if inference_mode != (n_img_prior is None):
249 | msg = ("Number of images for top-down generation has to be given "
250 | "if and only if we're not doing inference")
251 | raise RuntimeError(msg)
252 | if inference_mode and prior_experiment:
253 | msg = ("Prior experiments (e.g. sampling from mode) are not"
254 | " compatible with inference mode")
255 | raise RuntimeError(msg)
256 |
257 | # Sampled latent variables at each layer
258 | z = [None] * self.n_layers
259 |
260 | # KL divergence of each layer
261 | kl = [None] * self.n_layers
262 |
263 | # Spatial map of KL divergence for each layer
264 | kl_spatial = [None] * self.n_layers
265 |
266 | if forced_latent is None:
267 | forced_latent = [None] * self.n_layers
268 |
269 | # log p(z) where z is the sample in the topdown pass
270 | logprob_p = 0.
271 |
272 | # Top-down inference/generation loop
273 | out = out_pre_residual = None
274 | for i in reversed(range(self.n_layers)):
275 |
276 | # If available, get deterministic node from bottom-up inference
277 | try:
278 | bu_value = bu_values[i]
279 | except TypeError:
280 | bu_value = None
281 |
282 | # Whether the current layer should be sampled from the mode
283 | use_mode = i in mode_layers
284 | constant_out = i in constant_layers
285 |
286 | # Input for skip connection
287 | skip_input = out # TODO or out_pre_residual? or both?
288 |
289 | # Full top-down layer, including sampling and deterministic part
290 | out, out_pre_residual, aux = self.top_down_layers[i](
291 | out,
292 | skip_connection_input=skip_input,
293 | inference_mode=inference_mode,
294 | bu_value=bu_value,
295 | n_img_prior=n_img_prior,
296 | use_mode=use_mode,
297 | force_constant_output=constant_out,
298 | forced_latent=forced_latent[i],
299 | )
300 | z[i] = aux['z'] # sampled variable at this layer (batch, ch, h, w)
301 | kl[i] = aux['kl_samplewise'] # (batch, )
302 | kl_spatial[i] = aux['kl_spatial'] # (batch, h, w)
303 | logprob_p += aux['logprob_p'].mean() # mean over batch
304 |
305 | # Final top-down layer
306 | out = self.final_top_down(out)
307 |
308 | data = {
309 | 'z': z, # list of tensors with shape (batch, ch[i], h[i], w[i])
310 | 'kl': kl, # list of tensors with shape (batch, )
311 | 'kl_spatial':
312 | kl_spatial, # list of tensors w shape (batch, h[i], w[i])
313 | 'logprob_p': logprob_p, # scalar, mean over batch
314 | }
315 | return out, data
316 |
317 | def pad_input(self, x):
318 | """
319 |         Pads input x so that its sizes are multiples of the overall downscale factor
320 |         :param x: input image tensor
321 |         :return: Padded tensor
322 | """
323 | size = self.get_padded_size(x.size())
324 | x = pad_img_tensor(x, size)
325 | return x
326 |
327 | def get_padded_size(self, size):
328 | """
329 | Returns the smallest size (H, W) of the image with actual size given
330 | as input, such that H and W are powers of 2.
331 | :param size: input size, tuple either (N, C, H, w) or (H, W)
332 | :return: 2-tuple (H, W)
333 | """
334 |
335 | # Overall downscale factor from input to top layer (power of 2)
336 | dwnsc = self.overall_downscale_factor
337 |
338 |         # Make size argument into (height, width)
339 | if len(size) == 4:
340 | size = size[2:]
341 | if len(size) != 2:
342 | msg = ("input size must be either (N, C, H, W) or (H, W), but it "
343 | "has length {} (size={})".format(len(size), size))
344 | raise RuntimeError(msg)
345 |
346 |         # Output smallest multiples of dwnsc that are >= current sizes
347 | padded_size = list(((s - 1) // dwnsc + 1) * dwnsc for s in size)
348 |
349 | return padded_size
350 |
351 | def sample_prior(self, n_imgs, mode_layers=None, constant_layers=None):
352 |
353 | # Generate from prior
354 | out, _ = self.topdown_pass(n_img_prior=n_imgs,
355 | mode_layers=mode_layers,
356 | constant_layers=constant_layers)
357 | out = crop_img_tensor(out, self.img_shape)
358 |
359 | # Log likelihood and other info (per data point)
360 | _, likelihood_data = self.likelihood(out, None)
361 |
362 | return likelihood_data['sample']
363 |
364 | def get_top_prior_param_shape(self, n_imgs=1):
365 | # TODO num channels depends on random variable we're using
366 | dwnsc = self.overall_downscale_factor
367 | sz = self.get_padded_size(self.img_shape)
368 | h = sz[0] // dwnsc
369 | w = sz[1] // dwnsc
370 | c = self.z_dims[-1] * 2 # mu and logvar
371 | top_layer_shape = (n_imgs, c, h, w)
372 | return top_layer_shape
373 |
--------------------------------------------------------------------------------
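
A minimal sketch of constructing and running LadderVAE directly (training
normally goes through main.py and boilr; the hyperparameters below are
illustrative, chosen so that every required optional argument is set):

    import torch
    from models.lvae import LadderVAE

    model = LadderVAE(color_ch=1,
                      z_dims=[32, 32],
                      downsample=[1, 1],
                      merge_type='residual',
                      likelihood_form='bernoulli',
                      res_block_type='bacdbac',
                      img_shape=(28, 28))
    x = torch.rand(4, 1, 28, 28)
    out = model(x)
    print(out['kl'])                  # scalar: mean KL over the batch
    print(out['kl_loss'])             # KL term used in the loss (with free bits)
    samples = model.sample_prior(4)   # decoded samples, cropped to img_shape

Note the padding arithmetic: with downsample=[1, 1] and the default initial
downscaling, the overall downscale factor is 2^3 = 8, so 28x28 inputs are
padded to 32x32 (the smallest multiple of 8) before the bottom-up pass.
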
/models/lvae_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from lib.nn import ResidualBlock, ResidualGatedBlock
5 | from lib.stochastic import NormalStochasticBlock2d
6 |
7 |
8 | class TopDownLayer(nn.Module):
9 | """
10 | Top-down layer, including stochastic sampling, KL computation, and small
11 | deterministic ResNet with upsampling.
12 |
13 | The architecture when doing inference is roughly as follows:
14 | p_params = output of top-down layer above
15 | bu = inferred bottom-up value at this layer
16 | q_params = merge(bu, p_params)
17 | z = stochastic_layer(q_params)
18 | possibly get skip connection from previous top-down layer
19 | top-down deterministic ResNet
20 |
21 | When doing generation only, the value bu is not available, the
22 | merge layer is not used, and z is sampled directly from p_params.
23 |
24 |     If this is the top layer, the uppermost bottom-up value is used directly
25 |     as q_params at inference time, and p_params are defined in this layer
26 |     (rather than taken from the layer above) and can optionally be learned.
27 | """
28 |
29 | def __init__(self,
30 | z_dim,
31 | n_res_blocks,
32 | n_filters,
33 | is_top_layer=False,
34 | downsampling_steps=None,
35 | nonlin=None,
36 | merge_type=None,
37 | batchnorm=True,
38 | dropout=None,
39 | stochastic_skip=False,
40 | res_block_type=None,
41 | gated=None,
42 | learn_top_prior=False,
43 | top_prior_param_shape=None,
44 | analytical_kl=False):
45 |
46 | super().__init__()
47 |
48 | self.is_top_layer = is_top_layer
49 | self.z_dim = z_dim
50 | self.stochastic_skip = stochastic_skip
51 | self.learn_top_prior = learn_top_prior
52 | self.analytical_kl = analytical_kl
53 |
54 | # Define top layer prior parameters, possibly learnable
55 | if is_top_layer:
56 | self.top_prior_params = nn.Parameter(
57 | torch.zeros(top_prior_param_shape),
58 | requires_grad=learn_top_prior)
59 |
60 | # Downsampling steps left to do in this layer
61 | dws_left = downsampling_steps
62 |
63 | # Define deterministic top-down block: sequence of deterministic
64 | # residual blocks with downsampling when needed.
65 | block_list = []
66 | for _ in range(n_res_blocks):
67 | do_resample = False
68 | if dws_left > 0:
69 | do_resample = True
70 | dws_left -= 1
71 | block_list.append(
72 | TopDownDeterministicResBlock(
73 | n_filters,
74 | n_filters,
75 | nonlin,
76 | upsample=do_resample,
77 | batchnorm=batchnorm,
78 | dropout=dropout,
79 | res_block_type=res_block_type,
80 | gated=gated,
81 | ))
82 | self.deterministic_block = nn.Sequential(*block_list)
83 |
84 | # Define stochastic block with 2d convolutions
85 | self.stochastic = NormalStochasticBlock2d(
86 | c_in=n_filters,
87 | c_vars=z_dim,
88 | c_out=n_filters,
89 | transform_p_params=(not is_top_layer),
90 | )
91 |
92 | if not is_top_layer:
93 |
94 | # Merge layer, combine bottom-up inference with top-down
95 | # generative to give posterior parameters
96 | self.merge = MergeLayer(
97 | channels=n_filters,
98 | merge_type=merge_type,
99 | nonlin=nonlin,
100 | batchnorm=batchnorm,
101 | dropout=dropout,
102 | res_block_type=res_block_type,
103 | )
104 |
105 | # Skip connection that goes around the stochastic top-down layer
106 | if stochastic_skip:
107 | self.skip_connection_merger = SkipConnectionMerger(
108 | channels=n_filters,
109 | nonlin=nonlin,
110 | batchnorm=batchnorm,
111 | dropout=dropout,
112 | res_block_type=res_block_type,
113 | )
114 |
115 | def forward(self,
116 | input_=None,
117 | skip_connection_input=None,
118 | inference_mode=False,
119 | bu_value=None,
120 | n_img_prior=None,
121 | forced_latent=None,
122 | use_mode=False,
123 | force_constant_output=False):
124 |
125 | # Check consistency of arguments
126 | inputs_none = input_ is None and skip_connection_input is None
127 | if self.is_top_layer and not inputs_none:
128 | raise ValueError("In top layer, inputs should be None")
129 |
130 | # If top layer, define parameters of prior p(z_L)
131 | if self.is_top_layer:
132 | p_params = self.top_prior_params
133 |
134 | # Sample specific number of images by expanding the prior
135 | if n_img_prior is not None:
136 | p_params = p_params.expand(n_img_prior, -1, -1, -1)
137 |
138 | # Else the input from the layer above is the prior parameters
139 | else:
140 | p_params = input_
141 |
142 | # In inference mode, get parameters of q from inference path,
143 | # merging with top-down path if it's not the top layer
144 | if inference_mode:
145 | if self.is_top_layer:
146 | q_params = bu_value
147 | else:
148 | q_params = self.merge(bu_value, p_params)
149 |
150 | # In generative mode, q is not used
151 | else:
152 | q_params = None
153 |
154 | # Sample from either q(z_i | z_{i+1}, x) or p(z_i | z_{i+1})
155 | # depending on whether q_params is None
156 | x, data_stoch = self.stochastic(
157 | p_params=p_params,
158 | q_params=q_params,
159 | forced_latent=forced_latent,
160 | use_mode=use_mode,
161 | force_constant_output=force_constant_output,
162 | analytical_kl=self.analytical_kl,
163 | )
164 |
165 | # Skip connection from previous layer
166 | if self.stochastic_skip and not self.is_top_layer:
167 | x = self.skip_connection_merger(x, skip_connection_input)
168 |
169 | # Save activation before residual block: could be the skip
170 | # connection input in the next layer
171 | x_pre_residual = x
172 |
173 | # Last top-down block (sequence of residual blocks)
174 | x = self.deterministic_block(x)
175 |
176 | keys = ['z', 'kl_samplewise', 'kl_spatial', 'logprob_p', 'logprob_q']
177 | data = {k: data_stoch[k] for k in keys}
178 | return x, x_pre_residual, data
179 |
180 |
181 | class BottomUpLayer(nn.Module):
182 | """
183 | Bottom-up deterministic layer for inference, roughly the same as the
184 | small deterministic Resnet in top-down layers. Consists of a sequence of
185 | bottom-up deterministic residual blocks with downsampling.
186 | """
187 |
188 | def __init__(self,
189 | n_res_blocks,
190 | n_filters,
191 | downsampling_steps=0,
192 | nonlin=None,
193 | batchnorm=True,
194 | dropout=None,
195 | res_block_type=None,
196 | gated=None):
197 | super().__init__()
198 |
199 | bu_blocks = []
200 | for _ in range(n_res_blocks):
201 | do_resample = False
202 | if downsampling_steps > 0:
203 | do_resample = True
204 | downsampling_steps -= 1
205 | bu_blocks.append(
206 | BottomUpDeterministicResBlock(
207 | c_in=n_filters,
208 | c_out=n_filters,
209 | nonlin=nonlin,
210 | downsample=do_resample,
211 | batchnorm=batchnorm,
212 | dropout=dropout,
213 | res_block_type=res_block_type,
214 | gated=gated,
215 | ))
216 | self.net = nn.Sequential(*bu_blocks)
217 |
218 | def forward(self, x):
219 | return self.net(x)
220 |
221 |
222 | class ResBlockWithResampling(nn.Module):
223 | """
224 | Residual block that takes care of resampling steps (each by a factor of 2).
225 |
226 | The mode can be top-down or bottom-up, and the block does up- and
227 | down-sampling by a factor of 2, respectively. Resampling is performed at
228 | the beginning of the block, through strided convolution.
229 |
230 | The number of channels is adjusted at the beginning and end of the block,
231 | through convolutional layers with kernel size 1. The number of internal
232 | channels is by default the same as the number of output channels, but
233 | min_inner_channels overrides this behaviour.
234 |
235 | Other parameters: kernel size, nonlinearity, and groups of the internal
236 | residual block; whether batch normalization and dropout are performed;
237 | whether the residual path has a gate layer at the end. There are a few
238 | residual block structures to choose from.
239 | """
240 |
241 | def __init__(self,
242 | mode,
243 | c_in,
244 | c_out,
245 | nonlin=nn.LeakyReLU,
246 | resample=False,
247 | res_block_kernel=None,
248 | groups=1,
249 | batchnorm=True,
250 | res_block_type=None,
251 | dropout=None,
252 | min_inner_channels=None,
253 | gated=None):
254 | super().__init__()
255 | assert mode in ['top-down', 'bottom-up']
256 | if min_inner_channels is None:
257 | min_inner_channels = 0
258 | inner_filters = max(c_out, min_inner_channels)
259 |
260 | # Define first conv layer to change channels and/or up/downsample
261 | if resample:
262 | if mode == 'bottom-up': # downsample
263 | self.pre_conv = nn.Conv2d(in_channels=c_in,
264 | out_channels=inner_filters,
265 | kernel_size=3,
266 | padding=1,
267 | stride=2,
268 | groups=groups)
269 | elif mode == 'top-down': # upsample
270 | self.pre_conv = nn.ConvTranspose2d(in_channels=c_in,
271 | out_channels=inner_filters,
272 | kernel_size=3,
273 | padding=1,
274 | stride=2,
275 | groups=groups,
276 | output_padding=1)
277 | elif c_in != inner_filters:
278 | self.pre_conv = nn.Conv2d(c_in, inner_filters, 1, groups=groups)
279 | else:
280 | self.pre_conv = None
281 |
282 | # Residual block
283 | self.res = ResidualBlock(
284 | channels=inner_filters,
285 | nonlin=nonlin,
286 | kernel=res_block_kernel,
287 | groups=groups,
288 | batchnorm=batchnorm,
289 | dropout=dropout,
290 | gated=gated,
291 | block_type=res_block_type,
292 | )
293 |
294 | # Define last conv layer to get correct num output channels
295 | if inner_filters != c_out:
296 | self.post_conv = nn.Conv2d(inner_filters, c_out, 1, groups=groups)
297 | else:
298 | self.post_conv = None
299 |
300 | def forward(self, x):
301 | if self.pre_conv is not None:
302 | x = self.pre_conv(x)
303 | x = self.res(x)
304 | if self.post_conv is not None:
305 | x = self.post_conv(x)
306 | return x
307 |
308 |
309 | class TopDownDeterministicResBlock(ResBlockWithResampling):
310 |
311 | def __init__(self, *args, upsample=False, **kwargs):
312 | kwargs['resample'] = upsample
313 | super().__init__('top-down', *args, **kwargs)
314 |
315 |
316 | class BottomUpDeterministicResBlock(ResBlockWithResampling):
317 |
318 | def __init__(self, *args, downsample=False, **kwargs):
319 | kwargs['resample'] = downsample
320 | super().__init__('bottom-up', *args, **kwargs)
321 |
322 |
323 | class MergeLayer(nn.Module):
324 | """
325 | Merge two 4D input tensors by concatenating along dim=1 and passing the
326 | result through 1) a convolutional 1x1 layer, or 2) a residual block
327 | """
328 |
329 | def __init__(self,
330 | channels,
331 | merge_type,
332 | nonlin=nn.LeakyReLU,
333 | batchnorm=True,
334 | dropout=None,
335 | res_block_type=None):
336 | super().__init__()
337 | try:
338 | iter(channels)
339 | except TypeError: # it is not iterable
340 | channels = [channels] * 3
341 | else: # it is iterable
342 | if len(channels) == 1:
343 | channels = [channels[0]] * 3
344 | assert len(channels) == 3
345 |
346 | if merge_type == 'linear':
347 | self.layer = nn.Conv2d(channels[0] + channels[1], channels[2], 1)
348 | elif merge_type == 'residual':
349 | self.layer = nn.Sequential(
350 | nn.Conv2d(channels[0] + channels[1], channels[2], 1, padding=0),
351 | ResidualGatedBlock(channels[2],
352 | nonlin,
353 | batchnorm=batchnorm,
354 | dropout=dropout,
355 | block_type=res_block_type),
356 | )
357 |
358 | def forward(self, x, y):
359 | x = torch.cat((x, y), dim=1)
360 | return self.layer(x)
361 |
362 |
363 | class SkipConnectionMerger(MergeLayer):
364 | """
365 |     For now, simply a merge layer with residual merge type.
366 | """
367 |
368 | merge_type = 'residual'
369 |
370 | def __init__(self, channels, nonlin, batchnorm, dropout, res_block_type):
371 | super().__init__(channels,
372 | self.merge_type,
373 | nonlin,
374 | batchnorm,
375 | dropout=dropout,
376 | res_block_type=res_block_type)
377 |
--------------------------------------------------------------------------------
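
A sketch of MergeLayer combining bottom-up and top-down feature maps into
posterior parameters, as done inside TopDownLayer (channel counts are
illustrative):

    import torch
    from models.lvae_layers import MergeLayer

    merge = MergeLayer(channels=32,
                       merge_type='residual',
                       res_block_type='bacdbac')
    bu = torch.randn(8, 32, 16, 16)   # bottom-up inference features
    td = torch.randn(8, 32, 16, 16)   # top-down (prior) features
    q_in = merge(bu, td)              # (8, 32, 16, 16), fed to the q(z) convolution
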
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.18.1,<=1.19.0
2 | torch>=1.4.0,<=1.5.1
3 | torchvision>=0.5.0,<=0.6.1
4 | matplotlib>=3.1.2
5 | seaborn>=0.9.0,<=0.10.1
6 | boilr==0.7.4
7 | multiobject==0.0.3
8 |
--------------------------------------------------------------------------------