├── .gitignore ├── LICENSE ├── README.md ├── _readme_figs ├── celeba_sample_1.png ├── celeba_sample_1_downsc.png ├── cifar_sample_1.png ├── eq_layer_inspection.png ├── eqns.txt ├── layers_celeba │ ├── sample_mode_layer10.png │ ├── sample_mode_layer13.png │ ├── sample_mode_layer15.png │ ├── sample_mode_layer19.png │ └── sample_mode_layer5.png ├── layers_cifar │ ├── sample_mode_layer14.png │ ├── sample_mode_layer2.png │ ├── sample_mode_layer6.png │ └── sample_mode_layer9.png ├── layers_mnist │ ├── sample_mode_layer11.png │ ├── sample_mode_layer3.png │ ├── sample_mode_layer7.png │ └── sample_mode_layer9.png ├── layers_multidsprites │ ├── sample_mode_layer10.png │ ├── sample_mode_layer11.png │ ├── sample_mode_layer2.png │ ├── sample_mode_layer6.png │ └── sample_mode_layer9.png ├── layers_svhn │ ├── sample_mode_layer12.png │ ├── sample_mode_layer14.png │ ├── sample_mode_layer3.png │ └── sample_mode_layer9.png ├── lvae_eq.png ├── mnist_sample_1.png ├── multidsprites_sample_1_downsc.png └── svhn_sample_1.png ├── data ├── multi-dsprites-binary-rgb │ └── multi_dsprites_color_012.npz └── multi_mnist │ └── multi_binary_mnist_012.npz ├── evaluate.py ├── experiment ├── __init__.py ├── data.py └── experiment_manager.py ├── lib ├── datasets.py ├── likelihoods.py ├── nn.py └── stochastic.py ├── main.py ├── models ├── __init__.py ├── lvae.py └── lvae_layers.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .DS_Store 3 | __pycache__/ 4 | checkpoints/ 5 | data/ 6 | evaluation_results/ 7 | results/ 8 | tensorboard_logs/ 9 | *.pyc 10 | 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Andrea Dittadi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ladder Variational Autoencoders (LVAE) 2 | 3 | PyTorch implementation of Ladder Variational Autoencoders (LVAE) [1]: 4 | 5 |                  6 | ![LVAE equation](_readme_figs/lvae_eq.png) 7 | 8 | where the variational distributions _q_ at each layer are multivariate Normal 9 | with diagonal covariance. 
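For intuition, here is a minimal, self-contained sketch of what one stochastic layer does: reparameterized sampling from a diagonal Gaussian, plus its KL term. This is an illustration only, not the repository's actual code; see `lib/stochastic.py` and `models/lvae_layers.py` for the real implementation.

```python
import torch
from torch.distributions import Normal, kl_divergence

def sample_stochastic_layer(q_mean, q_logvar, p_mean, p_logvar):
    # Posterior q and prior p are both Normal with diagonal covariance
    q = Normal(q_mean, (0.5 * q_logvar).exp())
    p = Normal(p_mean, (0.5 * p_logvar).exp())
    z = q.rsample()           # reparameterized sample: gradients flow through
    kl = kl_divergence(q, p)  # analytical KL, elementwise
    return z, kl.sum(dim=(1, 2, 3))  # sum over latent dims, keep batch dim

# Spatial (convolutional) latents: batch of 16, 32 channels on an 8x8 grid
q_mu, q_lv = torch.randn(16, 32, 8, 8), torch.randn(16, 32, 8, 8)
z, kl = sample_stochastic_layer(q_mu, q_lv,
                                torch.zeros_like(q_mu), torch.zeros_like(q_lv))
print(z.shape, kl.shape)  # torch.Size([16, 32, 8, 8]) torch.Size([16])
```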
10 | 11 | **Significant differences from [1]** include: 12 | - skip connections in the generative path: conditioning on _all_ layers above 13 | rather than only on _the_ layer above (see for example [2]) 14 | - spatial (convolutional) latent variables 15 | - free bits [3] instead of beta annealing [4] 16 | 17 | ### Install requirements and run MNIST example 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | CUDA_VISIBLE_DEVICES=0 python main.py --zdims 32 32 32 --downsample 1 1 1 --nonlin elu --skip --blocks-per-layer 4 --gated --freebits 0.5 --learn-top-prior --data-dep-init --seed 42 --dataset static_mnist 22 | ``` 23 | 24 | Dependencies include [boilr](https://github.com/addtt/boiler-pytorch) (a framework 25 | for PyTorch) and [multiobject](https://github.com/addtt/multi-object-datasets) 26 | (which provides multi-object datasets with PyTorch dataloaders). 27 | 28 | 29 | 30 | ## Likelihood results 31 | 32 | Log likelihood bounds on the test set (average over 4 random seeds). 33 | 34 | | dataset | num layers | -ELBO | - log _p(x)_ ≤
[100 iws] | - log _p(x)_ ≤
[1000 iws] | 35 | | -------------------- |:----------:|:------------:|:-------------:|:--------------:| 36 | | binarized MNIST | 3 | 82.14 | 79.47 | 79.24 | 37 | | binarized MNIST | 6 | 80.74 | 78.65 | 78.52 | 38 | | binarized MNIST | 12 | 80.50 | 78.50 | 78.30 | 39 | | multi-dSprites (0-2) | 12 | 26.9 | 23.2 | | 40 | | SVHN | 15 | 4012 (1.88) | 3973 (1.87) | | 41 | | CIFAR10 | 3 | 7651 (3.59) | 7591 (3.56) | | 42 | | CIFAR10 | 6 | 7321 (3.44) | 7268 (3.41) | | 43 | | CIFAR10 | 15 | 7128 (3.35) | 7068 (3.32) | | 44 | | CelebA | 20 | 20026 (2.35) | 19913 (2.34) | | 45 | 46 | Note: 47 | - Bits per dimension in brackets. 48 | - 'iws' stands for importance weighted samples. More samples means tighter log 49 | likelihood lower bound. The bound converges to the actual log likelihood as 50 | the number of samples goes to infinity [5]. Note that the model is always 51 | trained with the ELBO (1 sample). 52 | - Each pixel in the images is modeled independently. The likelihood is Bernoulli 53 | for binary images, and discretized mixture of logistics with 10 54 | components [6] otherwise. 55 | - One day I'll get around to evaluating the IW bound on all datasets with 10000 samples. 56 | 57 | 58 | ## Supported datasets 59 | 60 | - Statically binarized MNIST [7], see Hugo Larochelle's website `http://www.cs.toronto.edu/~larocheh/public/datasets/` 61 | - [SVHN](http://ufldl.stanford.edu/housenumbers/) 62 | - [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) 63 | - [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) rescaled and cropped to 64x64 – see code for details. The path in `experiment.data.DatasetLoader` has to be modified 64 | - [binary multi-dSprites](https://github.com/addtt/multi-object-datasets): 64x64 RGB shapes (0 to 2) in each image 65 | 66 | 67 | ## Samples 68 | 69 | #### Binarized MNIST 70 | 71 | ![MNIST samples](_readme_figs/mnist_sample_1.png) 72 | 73 | #### Multi-dSprites 74 | 75 | ![multi-dSprites samples](_readme_figs/multidsprites_sample_1_downsc.png) 76 | 77 | #### SVHN 78 | 79 | ![SVHN samples](_readme_figs/svhn_sample_1.png) 80 | 81 | #### CIFAR 82 | 83 | ![CIFAR samples](_readme_figs/cifar_sample_1.png) 84 | 85 | #### CelebA 86 | 87 | ![CelebA samples](_readme_figs/celeba_sample_1_downsc.png) 88 | 89 | 90 | ## Hierarchical representations 91 | 92 | Here we try to visualize the representations learned by individual layers. 93 | We can get a rough idea of what's going on at layer _i_ as follows: 94 | 95 | - Sample latent variables from all layers above layer _i_ (Eq. 1). 96 | 97 | - With these variables fixed, take _S_ conditional samples at layer _i_ (Eq. 2). Note 98 | that they are all conditioned on the same samples. These correspond to one 99 | row in the images below. 100 | 101 | - For each of these samples (each small image in the images below), pick the 102 | mode/mean of the conditional distribution of each layer below (Eq. 3). 103 | 104 | - Finally, sample an image ***x*** given the latent variables (Eq. 4). 105 | 106 | Formally: 107 | 108 |                  109 | ![](_readme_figs/eq_layer_inspection.png) 110 | 111 | where _s_ = 1, ..., _S_ denotes the sample index. 112 | 113 | The equations above yield _S_ sample images conditioned on the same values of 114 | ***z*** for layers _i_+1 to _L_. These _S_ samples are shown in one row of the 115 | images below. 116 | Notice that samples from each row are almost identical when the variability comes 117 | from a low-level layer, as such layers mostly model local structure and details. 
118 | Higher layers on the other hand model global structure, and we observe more and 119 | more variability in each row as we move to higher layers. When the sampling 120 | happens in the top layer (_i = L_), all samples are completely independent, 121 | even within a row. 122 | 123 | #### Binarized MNIST: layers 4, 8, 10, and 12 (top layer) 124 | 125 | ![MNIST layers 4](_readme_figs/layers_mnist/sample_mode_layer3.png) 126 | ![MNIST layers 8](_readme_figs/layers_mnist/sample_mode_layer7.png) 127 | 128 | ![MNIST layers 10](_readme_figs/layers_mnist/sample_mode_layer9.png) 129 | ![MNIST layers 12](_readme_figs/layers_mnist/sample_mode_layer11.png) 130 | 131 | 132 | #### SVHN: layers 4, 10, 13, and 15 (top layer) 133 | 134 | ![SVHN layers 4](_readme_figs/layers_svhn/sample_mode_layer3.png) 135 | ![SVHN layers 10](_readme_figs/layers_svhn/sample_mode_layer9.png) 136 | 137 | ![SVHN layers 13](_readme_figs/layers_svhn/sample_mode_layer12.png) 138 | ![SVHN layers 15](_readme_figs/layers_svhn/sample_mode_layer14.png) 139 | 140 | 141 | #### CIFAR: layers 3, 7, 10, and 15 (top layer) 142 | 143 | ![CIFAR layers 3](_readme_figs/layers_cifar/sample_mode_layer2.png) 144 | ![CIFAR layers 7](_readme_figs/layers_cifar/sample_mode_layer6.png) 145 | 146 | ![CIFAR layers 10](_readme_figs/layers_cifar/sample_mode_layer9.png) 147 | ![CIFAR layers 15](_readme_figs/layers_cifar/sample_mode_layer14.png) 148 | 149 | 150 | #### CelebA: layers 6, 11, 16, and 20 (top layer) 151 | 152 | ![CelebA layers 6](_readme_figs/layers_celeba/sample_mode_layer5.png) 153 | 154 | ![CelebA layers 11](_readme_figs/layers_celeba/sample_mode_layer10.png) 155 | 156 | ![CelebA layers 16](_readme_figs/layers_celeba/sample_mode_layer15.png) 157 | 158 | ![CelebA layers 20](_readme_figs/layers_celeba/sample_mode_layer19.png) 159 | 160 | 161 | #### Multi-dSprites: layers 3, 7, 10, and 12 (top layer) 162 | 163 | ![multi-dSprites layers 3](_readme_figs/layers_multidsprites/sample_mode_layer2.png) 164 | ![multi-dSprites layers 7](_readme_figs/layers_multidsprites/sample_mode_layer6.png) 165 | 166 | ![multi-dSprites layers 10](_readme_figs/layers_multidsprites/sample_mode_layer9.png) 167 | ![multi-dSprites layers 12](_readme_figs/layers_multidsprites/sample_mode_layer11.png) 168 | 169 | 170 | 171 | ## Experimental details 172 | 173 | I did not perform an extensive hyperparameter search, but this worked pretty well: 174 | 175 | - Downsampling by a factor of 2 at the beginning of inference. 176 | After that, activations are downsampled 4 times for 64x64 images (CelebA and 177 | multi-dSprites), and 3 times otherwise. The spatial size of the final feature 178 | map is always 2x2. 179 | Between these downsampling steps there is approximately the same number of 180 | stochastic layers. 181 | - 4 residual blocks between stochastic layers. I haven't tried more 182 | than 4, though, as models become quite big and we get diminishing returns. 183 | - The deterministic parts of the bottom-up and top-down architectures are 184 | (almost) perfectly mirrored for simplicity. 185 | - Stochastic layers have spatial random variables, and the number of rvs per 186 | "location" (i.e. number of channels of the feature map after sampling from a 187 | layer) is 32 in all layers. 188 | - All other feature maps in deterministic paths have 64 channels. 189 | - Skip connections in the generative model (`--skip`). 190 | - Gated residual blocks (`--gated`). 191 | - Learned prior of the top layer (`--learn-top-prior`).
192 | - A form of data-dependent initialization of weights (`--data-dep-init`). 193 | See code for details. 194 | - freebits=1.0 in experiments with more than 6 stochastic layers, and 0.5 for 195 | smaller models. 196 | - For everything else, see `_add_args()` in `experiment/experiment_manager.py`. 197 | 198 | With these settings, the number of parameters is roughly 1M per stochastic 199 | layer. I tried to control for this by experimenting e.g. with half the number 200 | of layers but twice the number of residual blocks, but it looks like the number 201 | of stochastic layers is what matters the most. 202 | 203 | 204 | ## References 205 | 206 | [1] CK Sønderby, 207 | T Raiko, 208 | L Maaløe, 209 | SK Sønderby, 210 | O Winther. 211 | _Ladder Variational Autoencoders_, NIPS 2016 212 | 213 | [2] L Maaløe, M Fraccaro, V Liévin, O Winther. 214 | _BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling_, 215 | NeurIPS 2019 216 | 217 | [3] DP Kingma, 218 | T Salimans, 219 | R Jozefowicz, 220 | X Chen, 221 | I Sutskever, 222 | M Welling. 223 | _Improved Variational Inference with Inverse Autoregressive Flow_, 224 | NIPS 2016 225 | 226 | [4] I Higgins, L Matthey, A Pal, C Burgess, X Glorot, M Botvinick, 227 | S Mohamed, A Lerchner. 228 | _beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework_, 229 | ICLR 2017 230 | 231 | [5] Y Burda, RB Grosse, R Salakhutdinov. 232 | _Importance Weighted Autoencoders_, 233 | ICLR 2016 234 | 235 | [6] T Salimans, A Karpathy, X Chen, DP Kingma. 236 | _PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications_, 237 | ICLR 2017 238 | 239 | [7] H Larochelle, I Murray. 240 | _The neural autoregressive distribution estimator_, 241 | AISTATS 2011 242 | -------------------------------------------------------------------------------- /_readme_figs/celeba_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/celeba_sample_1.png -------------------------------------------------------------------------------- /_readme_figs/celeba_sample_1_downsc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/celeba_sample_1_downsc.png -------------------------------------------------------------------------------- /_readme_figs/cifar_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/cifar_sample_1.png -------------------------------------------------------------------------------- /_readme_figs/eq_layer_inspection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/eq_layer_inspection.png -------------------------------------------------------------------------------- /_readme_figs/eqns.txt: -------------------------------------------------------------------------------- 1 | \begin{align*} 2 | p_{\theta}(\mathbf{x}, \mathbf{z}) &= p_{\theta}(\mathbf{z}) \ p_{\theta}(\mathbf{x} \;|\; \mathbf{z})\\[.7em] 3 | p_{\theta}(\mathbf{z}) &=p_{\theta}\left(\mathbf{z}_{L}\right) \ \prod_{i=1}^{L-1} 
p_{\theta}\left(\mathbf{z}_{i} \;|\; \mathbf{z}_{> i}\right) \\ 4 | q_{\phi, \theta}(\mathbf{z} \;|\; \mathbf{x}) &=q_{\phi, \theta}\left(\mathbf{z}_{L} \;|\; \mathbf{x}\right) \ \prod_{i=1}^{L-1} q_{\phi, \theta}\left(\mathbf{z}_{i} \;|\; \mathbf{z}_{>i}\;, \mathbf{x} \right) 5 | \end{align*} 6 | 7 | ----- 8 | 9 | \newcommand{\z}{\mathbf{z}} 10 | \newcommand{\x}{\mathbf{x}} 11 | \newcommand{\argmax}[1]{\underset{#1}{\operatorname{arg}\operatorname{max}}\;} 12 | 13 | \begin{align} 14 | \z_{>i} &\sim p(\z_{>i}) = p(\z_L) \prod_{j=i+1}^{L-1} p(\z_j \;|\; \z_{>j})\\ 15 | \z_i^{(s)} &\sim p(\z_i \;|\; \z_{>i})\\[1em] 16 | % \z_{j}^{(s)} &= \mathbb{E}_{p(\z_{j} \;|\; \z_{j+1:i}^{(s)}, \z_{>i})} \left[ \z_{j} \right], \qquad \mbox{for } j<i\\[1em] 17 | \z_{j}^{(s)} &= \argmax{\z_j} p(\z_j \;|\; \z_{j+1:i}^{(s)}, \z_{>i}) \qquad \mbox{for } j<i\\[1em] 18 | \x^{(s)} &\sim p(\x \;|\; \z_{\leq i}^{(s)}, \z_{>i} ) 19 | \end{align} 20 | -------------------------------------------------------------------------------- /_readme_figs/layers_celeba/sample_mode_layer10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer10.png -------------------------------------------------------------------------------- /_readme_figs/layers_celeba/sample_mode_layer13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer13.png -------------------------------------------------------------------------------- /_readme_figs/layers_celeba/sample_mode_layer15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer15.png -------------------------------------------------------------------------------- /_readme_figs/layers_celeba/sample_mode_layer19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer19.png -------------------------------------------------------------------------------- /_readme_figs/layers_celeba/sample_mode_layer5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_celeba/sample_mode_layer5.png -------------------------------------------------------------------------------- /_readme_figs/layers_cifar/sample_mode_layer14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer14.png -------------------------------------------------------------------------------- /_readme_figs/layers_cifar/sample_mode_layer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer2.png -------------------------------------------------------------------------------- /_readme_figs/layers_cifar/sample_mode_layer6.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer6.png -------------------------------------------------------------------------------- /_readme_figs/layers_cifar/sample_mode_layer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_cifar/sample_mode_layer9.png -------------------------------------------------------------------------------- /_readme_figs/layers_mnist/sample_mode_layer11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer11.png -------------------------------------------------------------------------------- /_readme_figs/layers_mnist/sample_mode_layer3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer3.png -------------------------------------------------------------------------------- /_readme_figs/layers_mnist/sample_mode_layer7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer7.png -------------------------------------------------------------------------------- /_readme_figs/layers_mnist/sample_mode_layer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_mnist/sample_mode_layer9.png -------------------------------------------------------------------------------- /_readme_figs/layers_multidsprites/sample_mode_layer10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer10.png -------------------------------------------------------------------------------- /_readme_figs/layers_multidsprites/sample_mode_layer11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer11.png -------------------------------------------------------------------------------- /_readme_figs/layers_multidsprites/sample_mode_layer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer2.png -------------------------------------------------------------------------------- /_readme_figs/layers_multidsprites/sample_mode_layer6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer6.png -------------------------------------------------------------------------------- /_readme_figs/layers_multidsprites/sample_mode_layer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_multidsprites/sample_mode_layer9.png -------------------------------------------------------------------------------- /_readme_figs/layers_svhn/sample_mode_layer12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer12.png -------------------------------------------------------------------------------- /_readme_figs/layers_svhn/sample_mode_layer14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer14.png -------------------------------------------------------------------------------- /_readme_figs/layers_svhn/sample_mode_layer3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer3.png -------------------------------------------------------------------------------- /_readme_figs/layers_svhn/sample_mode_layer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/layers_svhn/sample_mode_layer9.png -------------------------------------------------------------------------------- /_readme_figs/lvae_eq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/lvae_eq.png -------------------------------------------------------------------------------- /_readme_figs/mnist_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/mnist_sample_1.png -------------------------------------------------------------------------------- /_readme_figs/multidsprites_sample_1_downsc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/multidsprites_sample_1_downsc.png -------------------------------------------------------------------------------- /_readme_figs/svhn_sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/_readme_figs/svhn_sample_1.png -------------------------------------------------------------------------------- /data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz -------------------------------------------------------------------------------- /data/multi_mnist/multi_binary_mnist_012.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/data/multi_mnist/multi_binary_mnist_012.npz -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standalone script for a couple of simple evaluations/tests of trained models. 3 | """ 4 | 5 | import argparse 6 | import os 7 | import warnings 8 | 9 | import torch 10 | import torch.utils.data 11 | from boilr.eval import BaseOfflineEvaluator 12 | from boilr.utils.viz import img_grid_pad_value 13 | from torchvision.utils import save_image 14 | 15 | from experiment.experiment_manager import LVAEExperiment 16 | 17 | 18 | class Evaluator(BaseOfflineEvaluator): 19 | 20 | def run(self): 21 | 22 | torch.set_grad_enabled(False) 23 | 24 | n = 12 25 | 26 | e = self._experiment 27 | e.model.eval() 28 | 29 | # Run evaluation and print results 30 | results = e.test_procedure(iw_samples=self.args.ll_samples) 31 | print("Eval results:\n{}".format(results)) 32 | 33 | # Save samples 34 | for i in range(self.args.prior_samples): 35 | fname = os.path.join(self._img_folder, "samples_{}.png".format(i)) 36 | e.generate_and_save_samples(fname, nrows=n) 37 | 38 | # Save input and reconstructions 39 | x, y = next(iter(e.dataloaders.test)) 40 | fname = os.path.join(self._img_folder, "reconstructions.png") 41 | e.generate_and_save_reconstructions(x, fname, nrows=n) 42 | 43 | # Inspect representations learned by each layer 44 | if self.args.inspect_layer_repr: 45 | inspect_layer_repr(e.model, self._img_folder, n=n) 46 | 47 | # @classmethod 48 | # def _define_args_defaults(cls) -> dict: 49 | # defaults = super(Evaluator, cls)._define_args_defaults() 50 | # return defaults 51 | 52 | def _add_args(self, parser: argparse.ArgumentParser) -> None: 53 | 54 | super(Evaluator, self)._add_args(parser) 55 | 56 | parser.add_argument('--ll', 57 | action='store_true', 58 | help="estimate log likelihood with importance-" 59 | "weighted bound") 60 | parser.add_argument('--ll-samples', 61 | type=int, 62 | default=100, 63 | dest='ll_samples', 64 | metavar='N', 65 | help="number of importance-weighted samples for " 66 | "log likelihood estimation") 67 | parser.add_argument('--ps', 68 | type=int, 69 | default=1, 70 | dest='prior_samples', 71 | metavar='N', 72 | help="number of batches of samples from prior") 73 | parser.add_argument('--layer-repr', 74 | action='store_true', 75 | dest='inspect_layer_repr', 76 | help='inspect layer representations. Generate ' 77 | 'samples by sampling top layers once, then taking ' 78 | 'many samples from a middle layer, and finally ' 79 | 'sample the downstream layers from the conditional ' 80 | 'mode. Do this for every layer.') 81 | 82 | @classmethod 83 | def _check_args(cls, args: argparse.Namespace) -> argparse.Namespace: 84 | args = super(Evaluator, cls)._check_args(args) 85 | 86 | if not args.ll: 87 | args.ll_samples = 1 88 | if args.load_step is not None: 89 | warnings.warn( 90 | "Loading weights from specific training step is not supported " 91 | "for now. 
The model will be loaded from the last checkpoint.") 92 | return args 93 | 94 | 95 | def inspect_layer_repr(model, img_folder, n=8): 96 | for i in range(model.n_layers): 97 | 98 | # print('layer', i) 99 | 100 | mode_layers = range(i) 101 | constant_layers = range(i + 1, model.n_layers) 102 | 103 | # Sample top layers once, then take many samples of a middle layer, 104 | # then sample from the mode in all downstream layers. 105 | sample = [] 106 | for r in range(n): 107 | sample.append( 108 | model.sample_prior(n, 109 | mode_layers=mode_layers, 110 | constant_layers=constant_layers)) 111 | sample = torch.cat(sample) 112 | pad_value = img_grid_pad_value(sample) 113 | fname = os.path.join(img_folder, 'sample_mode_layer' + str(i) + '.png') 114 | save_image(sample, fname, nrow=n, pad_value=pad_value) 115 | 116 | 117 | def main(): 118 | evaluator = Evaluator(experiment_class=LVAEExperiment) 119 | evaluator() 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from .experiment_manager import LVAEExperiment 2 | -------------------------------------------------------------------------------- /experiment/data.py: -------------------------------------------------------------------------------- 1 | from multiobject.pytorch import MultiObjectDataset, MultiObjectDataLoader 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.datasets import CIFAR10, SVHN, CelebA 5 | 6 | from lib.datasets import StaticBinaryMnist 7 | 8 | multiobject_paths = { 9 | 'multi_mnist_binary': 10 | './data/multi_mnist/multi_binary_mnist_012.npz', 11 | 'multi_dsprites_binary_rgb': 12 | './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz', 13 | } 14 | multiobject_datasets = multiobject_paths.keys() 15 | 16 | 17 | class DatasetLoader: 18 | """ 19 | Wrapper for DataLoaders. 
Data attributes: 20 | - train: DataLoader object for training set 21 | - test: DataLoader object for test set 22 | - data_shape: shape of each data point (channels, height, width) 23 | - img_size: spatial dimensions of each data point (height, width) 24 | - color_ch: number of color channels 25 | """ 26 | 27 | def __init__(self, args, cuda): 28 | 29 | kwargs = {'num_workers': 1, 'pin_memory': False} if cuda else {} 30 | 31 | # Default dataloader class 32 | dataloader_class = DataLoader 33 | 34 | if args.dataset_name == 'static_mnist': 35 | data_folder = './data/static_bin_mnist/' 36 | train_set = StaticBinaryMnist(data_folder, 37 | train=True, 38 | download=True, 39 | shuffle_init=True) 40 | test_set = StaticBinaryMnist(data_folder, 41 | train=False, 42 | download=True, 43 | shuffle_init=True) 44 | 45 | elif args.dataset_name == 'cifar10': 46 | # Discrete values 0, 1/255, ..., 254/255, 1 47 | transform = transforms.Compose([ 48 | # Move values to the center of 256 bins 49 | # transforms.Lambda(lambda x: Image.eval( 50 | # x, lambda y: y * (255/256) + 1/512)), 51 | transforms.ToTensor(), 52 | ]) 53 | data_folder = './data/cifar10/' 54 | train_set = CIFAR10(data_folder, 55 | train=True, 56 | download=True, 57 | transform=transform) 58 | test_set = CIFAR10(data_folder, 59 | train=False, 60 | download=True, 61 | transform=transform) 62 | 63 | elif args.dataset_name == 'svhn': 64 | transform = transforms.ToTensor() 65 | data_folder = './data/svhn/' 66 | train_set = SVHN(data_folder, 67 | split='train', 68 | download=True, 69 | transform=transform) 70 | test_set = SVHN(data_folder, 71 | split='test', 72 | download=True, 73 | transform=transform) 74 | 75 | elif args.dataset_name == 'celeba': 76 | transform = transforms.Compose([ 77 | transforms.CenterCrop(148), 78 | transforms.Resize((64, 64)), 79 | transforms.ToTensor(), 80 | ]) 81 | data_folder = '/scratch/adit/data/celeba/' 82 | train_set = CelebA(data_folder, 83 | split='train', 84 | download=True, 85 | transform=transform) 86 | test_set = CelebA(data_folder, 87 | split='valid', 88 | download=True, 89 | transform=transform) 90 | 91 | elif args.dataset_name in multiobject_datasets: 92 | data_path = multiobject_paths[args.dataset_name] 93 | train_set = MultiObjectDataset(data_path, train=True) 94 | test_set = MultiObjectDataset(data_path, train=False) 95 | 96 | # Custom data loader class 97 | dataloader_class = MultiObjectDataLoader 98 | 99 | else: 100 | raise RuntimeError("Unrecognized data set '{}'".format( 101 | args.dataset_name)) 102 | 103 | self.train = dataloader_class(train_set, 104 | batch_size=args.batch_size, 105 | shuffle=True, 106 | drop_last=True, 107 | **kwargs) 108 | self.test = dataloader_class(test_set, 109 | batch_size=args.test_batch_size, 110 | shuffle=False, 111 | **kwargs) 112 | 113 | self.data_shape = self.train.dataset[0][0].size() 114 | self.img_size = self.data_shape[1:] 115 | self.color_ch = self.data_shape[0] 116 | -------------------------------------------------------------------------------- /experiment/experiment_manager.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import boilr.data 4 | import torch 5 | from boilr import VAEExperimentManager 6 | from boilr.nn.init import data_dependent_init 7 | from boilr.utils import linear_anneal 8 | from torch import optim 9 | from torch.optim.optimizer import Optimizer 10 | from typing import Optional 11 | from models.lvae import LadderVAE 12 | from .data import DatasetLoader 13 | 14 | 
boilr.set_options(model_print_depth=2) 15 | 16 | 17 | class LVAEExperiment(VAEExperimentManager): 18 | """ 19 | Experiment manager. 20 | 21 | Data attributes: 22 | - 'args': argparse.Namespace containing all config parameters. When 23 | initializing this object, if 'args' is not given, all config 24 | parameters are set based on experiment defaults and user input, using 25 | argparse. 26 | - 'run_description': string description of the run that includes a timestamp 27 | and can be used e.g. as folder name for logging. 28 | - 'model': the model. 29 | - 'device': torch.device that is being used 30 | - 'dataloaders': DataLoaders, with attributes 'train' and 'test' 31 | - 'optimizer': the optimizer 32 | """ 33 | 34 | def _make_datamanager(self) -> boilr.data.BaseDatasetManager: 35 | cuda = self.device.type == 'cuda' 36 | return DatasetLoader(self.args, cuda) 37 | 38 | def _make_model(self) -> torch.nn.Module: 39 | args = self.args 40 | model = LadderVAE( 41 | self.dataloaders.color_ch, 42 | z_dims=args.z_dims, 43 | blocks_per_layer=args.blocks_per_layer, 44 | downsample=args.downsample, 45 | merge_type=args.merge_layers, 46 | batchnorm=args.batch_norm, 47 | nonlin=args.nonlin, 48 | stochastic_skip=args.skip_connections, 49 | n_filters=args.n_filters, 50 | dropout=args.dropout, 51 | res_block_type=args.residual_type, 52 | free_bits=args.free_bits, 53 | learn_top_prior=args.learn_top_prior, 54 | img_shape=self.dataloaders.img_size, 55 | likelihood_form=args.likelihood, 56 | gated=args.gated, 57 | no_initial_downscaling=args.no_initial_downscaling, 58 | analytical_kl=args.analytical_kl, 59 | ).to(self.device) 60 | 61 | # Weight initialization 62 | if args.simple_data_dependent_init: 63 | 64 | # Get batch 65 | t = [ 66 | self.dataloaders.train.dataset[i] 67 | for i in range(args.batch_size) 68 | ] 69 | t = torch.stack(tuple(t[i][0] for i in range(len(t)))) 70 | 71 | # Use batch for data dependent init 72 | data_dependent_init(model, {'x': t.to(self.device)}) 73 | 74 | return model 75 | 76 | def _make_optimizer(self) -> Optimizer: 77 | args = self.args 78 | optimizer = optim.Adamax(self.model.parameters(), 79 | lr=args.lr, 80 | weight_decay=args.weight_decay) 81 | return optimizer 82 | 83 | @classmethod 84 | def _define_args_defaults(cls) -> dict: 85 | defaults = super(LVAEExperiment, cls)._define_args_defaults() 86 | 87 | # Override boilr defaults 88 | defaults.update( 89 | 90 | # General 91 | batch_size=64, 92 | test_batch_size=1000, 93 | lr=3e-4, 94 | train_log_every=10000, 95 | test_log_every=10000, 96 | checkpoint_every=100000, 97 | keep_checkpoint_max=2, 98 | resume="", 99 | 100 | # VI-specific 101 | loglikelihood_every=50000, 102 | loglikelihood_samples=100, 103 | ) 104 | 105 | return defaults 106 | 107 | def _add_args(self, parser: argparse.ArgumentParser) -> None: 108 | 109 | super(LVAEExperiment, self)._add_args(parser) 110 | 111 | def list_options(lst): 112 | if lst: 113 | return "'" + "' | '".join(lst) + "'" 114 | return "" 115 | 116 | legal_merge_layers = ['linear', 'residual'] 117 | legal_nonlin = ['relu', 'leakyrelu', 'elu', 'selu'] 118 | legal_resblock = ['cabdcabd', 'bacdbac', 'bacdbacd'] 119 | legal_datasets = [ 120 | 'static_mnist', 'cifar10', 'celeba', 'svhn', 121 | 'multi_dsprites_binary_rgb', 'multi_mnist_binary' 122 | ] 123 | legal_likelihoods = [ 124 | 'bernoulli', 'gaussian', 'discr_log', 'discr_log_mix' 125 | ] 126 | 127 | parser.add_argument('-d', 128 | '--dataset', 129 | type=str, 130 | choices=legal_datasets, 131 | default='static_mnist', 132 | metavar='NAME', 133 | 
dest='dataset_name', 134 | help="dataset: " + list_options(legal_datasets)) 135 | 136 | parser.add_argument('--likelihood', 137 | type=str, 138 | choices=legal_likelihoods, 139 | metavar='NAME', 140 | dest='likelihood', 141 | help="likelihood: {}; the default depends on the " 142 | "dataset".format(list_options(legal_likelihoods))) 143 | 144 | parser.add_argument('--zdims', 145 | nargs='+', 146 | type=int, 147 | default=[32, 32, 32], 148 | metavar='DIM', 149 | dest='z_dims', 150 | help='list of dimensions (number of channels) for ' 151 | 'each stochastic layer') 152 | 153 | parser.add_argument('--blocks-per-layer', 154 | type=int, 155 | default=2, 156 | metavar='N', 157 | help='residual blocks between stochastic layers') 158 | 159 | parser.add_argument('--nfilters', 160 | type=int, 161 | default=64, 162 | metavar='N', 163 | dest='n_filters', 164 | help='number of channels in all residual blocks') 165 | 166 | parser.add_argument('--no-bn', 167 | action='store_true', 168 | dest='no_batch_norm', 169 | help='do not use batch normalization') 170 | 171 | parser.add_argument('--skip', 172 | action='store_true', 173 | dest='skip_connections', 174 | help='skip connections in generative model') 175 | 176 | parser.add_argument('--gated', 177 | action='store_true', 178 | dest='gated', 179 | help='use gated layers in residual blocks') 180 | 181 | parser.add_argument('--downsample', 182 | nargs='+', 183 | type=int, 184 | default=[1, 1, 1], 185 | metavar='N', 186 | help='list of integers, each int is the number of ' 187 | 'downsampling steps (by a factor of 2) before each ' 188 | 'stochastic layer') 189 | 190 | parser.add_argument('--learn-top-prior', 191 | action='store_true', 192 | help="learn the top-layer prior") 193 | 194 | parser.add_argument('--residual-type', 195 | type=str, 196 | choices=legal_resblock, 197 | default='bacdbacd', 198 | metavar='TYPE', 199 | help="type of residual blocks: " + 200 | list_options(legal_resblock)) 201 | 202 | parser.add_argument('--merge-layers', 203 | type=str, 204 | choices=legal_merge_layers, 205 | default='residual', 206 | metavar='TYPE', 207 | help="type of merge layers: " + 208 | list_options(legal_merge_layers)) 209 | 210 | parser.add_argument('--beta-anneal', 211 | type=int, 212 | default=0, 213 | metavar='B', 214 | help='steps for annealing beta from 0 to 1') 215 | 216 | parser.add_argument('--data-dep-init', 217 | action='store_true', 218 | dest='simple_data_dependent_init', 219 | help='use simple data-dependent initialization to ' 220 | 'normalize outputs of affine layers') 221 | 222 | parser.add_argument('--wd', 223 | type=float, 224 | default=0.0, 225 | dest='weight_decay', 226 | help='weight decay') 227 | 228 | parser.add_argument('--nonlin', 229 | type=str, 230 | choices=legal_nonlin, 231 | default='elu', 232 | metavar='F', 233 | help="nonlinear activation: " + 234 | list_options(legal_nonlin)) 235 | 236 | parser.add_argument('--dropout', 237 | type=float, 238 | default=0.2, 239 | metavar='D', 240 | help='dropout probability (in deterministic ' 241 | 'layers)') 242 | 243 | parser.add_argument('--freebits', 244 | type=float, 245 | default=0.0, 246 | metavar='N', 247 | dest='free_bits', 248 | help='free bits (nats)') 249 | 250 | parser.add_argument('--analytical-kl', 251 | action='store_true', 252 | dest='analytical_kl', 253 | help='use analytical KL') 254 | 255 | parser.add_argument('--no-initial-downscaling', 256 | action='store_true', 257 | dest='no_initial_downscaling', 258 | help='do not downscale as first inference step (and ' 'upscale as last
generation step)') 260 | 261 | @classmethod 262 | def _check_args(cls, args: argparse.Namespace) -> argparse.Namespace: 263 | 264 | args = super(LVAEExperiment, cls)._check_args(args) 265 | 266 | if len(args.z_dims) != len(args.downsample): 267 | msg = ("length of list of latent dimensions ({}) does not match " 268 | "length of list of downsampling factors ({})").format( 269 | len(args.z_dims), len(args.downsample)) 270 | raise RuntimeError(msg) 271 | 272 | assert args.weight_decay >= 0.0 273 | assert 0.0 <= args.dropout <= 1.0 274 | if args.dropout < 1e-5: 275 | args.dropout = None 276 | assert args.free_bits >= 0.0 277 | args.batch_norm = not args.no_batch_norm 278 | 279 | likelihood_map = { 280 | 'static_mnist': 'bernoulli', 281 | 'multi_dsprites_binary_rgb': 'bernoulli', 282 | 'multi_mnist_binary': 'bernoulli', 283 | 'cifar10': 'discr_log_mix', 284 | 'celeba': 'discr_log_mix', 285 | 'svhn': 'discr_log_mix', 286 | } 287 | if args.likelihood is None: # default 288 | args.likelihood = likelihood_map[args.dataset_name] 289 | 290 | return args 291 | 292 | @staticmethod 293 | def _make_run_description(args: argparse.Namespace) -> str: 294 | s = '' 295 | s += args.dataset_name 296 | s += ',{}ly'.format(len(args.z_dims)) 297 | # s += ',z=' + str(args.z_dims).replace(" ", "") 298 | # s += ',dwn=' + str(args.downsample).replace(" ", "") 299 | s += ',{}bpl'.format(args.blocks_per_layer) 300 | s += ',{}ch'.format(args.n_filters) 301 | if args.skip_connections: 302 | s += ',skip' 303 | if args.gated: 304 | s += ',gate' 305 | s += ',block=' + args.residual_type 306 | if args.beta_anneal != 0: 307 | s += ',b{}'.format(args.beta_anneal) 308 | s += ',{}'.format(args.nonlin) 309 | if args.free_bits > 0: 310 | s += ',freeb={}'.format(args.free_bits) 311 | if args.dropout is not None: 312 | s += ',drop={}'.format(args.dropout) 313 | if args.learn_top_prior: 314 | s += ',learnp' 315 | if args.weight_decay > 0.0: 316 | s += ',wd={}'.format(args.weight_decay) 317 | s += ',seed{}'.format(args.seed) 318 | if len(args.additional_descr) > 0: 319 | s += ',' + args.additional_descr 320 | return s 321 | 322 | def forward_pass(self, 323 | x: torch.Tensor, 324 | y: Optional[torch.Tensor] = None) -> dict: 325 | 326 | # Forward pass 327 | x = x.to(self.device, non_blocking=True) 328 | model_out = self.model(x) 329 | recons_sep = -model_out['ll'] 330 | kl_sep = model_out['kl_sep'] 331 | kl = model_out['kl'] 332 | kl_loss = model_out['kl_loss'] 333 | 334 | # ELBO 335 | elbo_sep = -(recons_sep + kl_sep) 336 | elbo = elbo_sep.mean() 337 | 338 | # Loss with beta 339 | beta = 1. 
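# KL warm-up: beta is annealed linearly from 0 to 1 over the first beta_anneal steps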
340 | if self.args.beta_anneal != 0: 341 | beta = linear_anneal(self.model.global_step, 0.0, 1.0, 342 | self.args.beta_anneal) 343 | recons = recons_sep.mean() 344 | loss = recons + kl_loss * beta 345 | 346 | # L2 347 | l2 = 0.0 348 | for p in self.model.parameters(): 349 | l2 = l2 + torch.sum(p**2) 350 | l2 = l2.sqrt() 351 | 352 | output = { 353 | 'loss': loss, 354 | 'elbo': elbo, 355 | 'elbo_sep': elbo_sep, 356 | 'kl': kl, 357 | 'l2': l2, 358 | 'recons': recons, 359 | 'out_mean': model_out['out_mean'], 360 | 'out_mode': model_out['out_mode'], 361 | 'out_sample': model_out['out_sample'], 362 | 'likelihood_params': model_out['likelihood_params'], 363 | } 364 | if 'kl_avg_layerwise' in model_out: 365 | output['kl_avg_layerwise'] = model_out['kl_avg_layerwise'] 366 | 367 | return output 368 | 369 | @classmethod 370 | def train_log_str(cls, 371 | summaries: dict, 372 | step: int, 373 | epoch: Optional[int] = None) -> str: 374 | s = " [step {}] loss: {:.5g} ELBO: {:.5g} recons: {:.3g} KL: {:.3g}" 375 | s = s.format(step, summaries['loss/loss'], summaries['elbo/elbo'], 376 | summaries['elbo/recons'], summaries['elbo/kl']) 377 | return s 378 | 379 | @classmethod 380 | def test_log_str(cls, 381 | summaries: dict, 382 | step: int, 383 | epoch: Optional[int] = None) -> str: 384 | s = " " 385 | if epoch is not None: 386 | s += "[step {}, epoch {}] ".format(step, epoch) 387 | s += "ELBO {:.5g} recons: {:.3g} KL: {:.3g}".format( 388 | summaries['elbo/elbo'], summaries['elbo/recons'], 389 | summaries['elbo/kl']) 390 | ll_key = None 391 | for k in summaries.keys(): 392 | if k.find('elbo_IW') > -1: 393 | ll_key = k 394 | iw_samples = k.split('_')[-1] 395 | break 396 | if ll_key is not None: 397 | s += " marginal log-likelihood ({}) {:.5g}".format( 398 | iw_samples, summaries[ll_key]) 399 | 400 | return s 401 | 402 | @classmethod 403 | def get_metrics_dict(cls, results: dict) -> dict: 404 | metrics_dict = { 405 | 'loss/loss': results['loss'].item(), 406 | 'elbo/elbo': results['elbo'].item(), 407 | 'elbo/recons': results['recons'].item(), 408 | 'elbo/kl': results['kl'].item(), 409 | 'l2/l2': results['l2'].item(), 410 | } 411 | if 'kl_avg_layerwise' in results: 412 | for i in range(len(results['kl_avg_layerwise'])): 413 | key = 'kl_layers/kl_layer_{}'.format(i) 414 | metrics_dict[key] = results['kl_avg_layerwise'][i].item() 415 | return metrics_dict 416 | -------------------------------------------------------------------------------- /lib/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from urllib import request 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import TensorDataset 7 | 8 | 9 | class StaticBinaryMnist(TensorDataset): 10 | 11 | def __init__(self, folder, train, download=False, shuffle_init=False): 12 | self.download = download 13 | if train: 14 | sets = [ 15 | self._get_binarized_mnist(folder, shuffle_init, split='train'), 16 | self._get_binarized_mnist(folder, shuffle_init, split='valid') 17 | ] 18 | x = np.concatenate(sets, axis=0) 19 | else: 20 | x = self._get_binarized_mnist(folder, shuffle_init, split='test') 21 | labels = torch.zeros(len(x),).fill_(float('nan')) 22 | super().__init__(torch.from_numpy(x), labels) 23 | 24 | def _get_binarized_mnist(self, folder, shuffle_init, split=None): 25 | """ 26 | Get statically binarized MNIST. 
Code partially taken from 27 | https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py 28 | """ 29 | 30 | subdatasets = ['train', 'valid', 'test'] 31 | if split not in subdatasets: 32 | raise ValueError("Valid splits: {}".format(subdatasets)) 33 | data = {} 34 | 35 | fname = 'binarized_mnist_{}.npz'.format(split) 36 | path = os.path.join(folder, fname) 37 | 38 | if not os.path.exists(path): 39 | print("Dataset file '{}' not found".format(path)) 40 | if not self.download: 41 | msg = "Dataset not found, use download=True to download it" 42 | raise RuntimeError(msg) 43 | 44 | print("Downloading whole dataset...") 45 | 46 | os.makedirs(folder, exist_ok=True) 47 | 48 | for subdataset in subdatasets: 49 | fname_mat = 'binarized_mnist_{}.amat'.format(subdataset) 50 | url = ('http://www.cs.toronto.edu/~larocheh/public/datasets/' 51 | 'binarized_mnist/{}'.format(fname_mat)) 52 | path_mat = os.path.join(folder, fname_mat) 53 | request.urlretrieve(url, path_mat) 54 | 55 | with open(path_mat) as f: 56 | lines = f.readlines() 57 | 58 | os.remove(path_mat) 59 | lines = np.array( 60 | [[int(i) for i in line.split()] for line in lines]) 61 | data[subdataset] = lines.astype('float32').reshape( 62 | (-1, 1, 28, 28)) 63 | np.savez_compressed(path_mat.split(".amat")[0], 64 | data=data[subdataset]) 65 | 66 | else: 67 | data[split] = np.load(path)['data'] 68 | 69 | if shuffle_init: 70 | np.random.shuffle(data[split]) 71 | 72 | return data[split] 73 | 74 | 75 | def _pad_tensor(x, size, value=None): 76 | assert isinstance(x, torch.Tensor) 77 | input_size = len(x) 78 | if value is None: 79 | value = float('nan') 80 | 81 | # Copy input tensor into a tensor filled with specified value 82 | # Convert everything to float, not ideal but it's robust 83 | out = torch.zeros(*size, dtype=torch.float) 84 | out.fill_(value) 85 | if input_size > 0: # only if at least one element in the sequence 86 | out[:input_size] = x.float() 87 | return out 88 | -------------------------------------------------------------------------------- /lib/likelihoods.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.distributions import Normal 8 | from torch.nn import functional as F 9 | 10 | from .stochastic import logistic_rsample, sample_from_discretized_mix_logistic 11 | 12 | 13 | class LikelihoodModule(nn.Module): 14 | 15 | def distr_params(self, x): 16 | pass 17 | 18 | @staticmethod 19 | def mean(params): 20 | pass 21 | 22 | @staticmethod 23 | def mode(params): 24 | pass 25 | 26 | @staticmethod 27 | def sample(params): 28 | pass 29 | 30 | def log_likelihood(self, x, params): 31 | pass 32 | 33 | def forward(self, input_, x): 34 | distr_params = self.distr_params(input_) 35 | mean = self.mean(distr_params) 36 | mode = self.mode(distr_params) 37 | sample = self.sample(distr_params) 38 | if x is None: 39 | ll = None 40 | else: 41 | ll = self.log_likelihood(x, distr_params) 42 | dct = { 43 | 'mean': mean, 44 | 'mode': mode, 45 | 'sample': sample, 46 | 'params': distr_params, 47 | } 48 | return ll, dct 49 | 50 | 51 | class BernoulliLikelihood(LikelihoodModule): 52 | 53 | def __init__(self, ch_in, color_channels): 54 | super().__init__() 55 | self.parameter_net = nn.Conv2d(ch_in, 56 | color_channels, 57 | kernel_size=3, 58 | padding=1) 59 | 60 | def distr_params(self, x): 61 | x = self.parameter_net(x) 62 | x = torch.sigmoid(x) 63 | return x 64 | 65 | @staticmethod 
66 | def mean(params): 67 | return params 68 | 69 | @staticmethod 70 | def mode(params): 71 | return torch.round(params) 72 | 73 | @staticmethod 74 | def sample(params): 75 | return (torch.rand_like(params) < params).float() 76 | 77 | def log_likelihood(self, x, params): 78 | return log_bernoulli(x, params, reduce='none') 79 | 80 | 81 | class GaussianLikelihood(LikelihoodModule): 82 | 83 | def __init__(self, ch_in, color_channels): 84 | super().__init__() 85 | self.parameter_net = nn.Conv2d(ch_in, 86 | 2 * color_channels, 87 | kernel_size=3, 88 | padding=1) 89 | 90 | def distr_params(self, x): 91 | x = self.parameter_net(x) 92 | mean, lv = x.chunk(2, dim=1) 93 | params = { 94 | 'mean': mean, 95 | 'logvar': lv, 96 | } 97 | return params 98 | 99 | @staticmethod 100 | def mean(params): 101 | return params['mean'] 102 | 103 | @staticmethod 104 | def mode(params): 105 | return params['mean'] 106 | 107 | @staticmethod 108 | def sample(params): 109 | p = Normal(params['mean'], (params['logvar'] / 2).exp()) 110 | return p.rsample() 111 | 112 | def log_likelihood(self, x, params): 113 | logprob = log_normal(x, params['mean'], params['logvar'], reduce='none') 114 | return logprob 115 | 116 | 117 | class DiscretizedLogisticLikelihood(LikelihoodModule): 118 | """ 119 | Assume input data to be originally uint8 (0, ..., 255) and then rescaled 120 | by 1/255: discrete values in {0, 1/255, ..., 255/255}. 121 | If using the discretized logistic logprob implementation here, the data should 122 | be rescaled by 255/256 and shifted by <1/256 in this class. So the data is 123 | inside 256 bins between 0 and 1. 124 | 125 | Note that mean and logscale are parameters of the underlying continuous 126 | logistic distribution, not of its discretization. 127 | """ 128 | 129 | log_scale_bias = -1. 130 | 131 | def __init__(self, ch_in, color_channels, n_bins, double=False): 132 | super().__init__() 133 | self.n_bins = n_bins 134 | self.double_precision = double 135 | self.parameter_net = nn.Conv2d(ch_in, 136 | 2 * color_channels, 137 | kernel_size=3, 138 | padding=1) 139 | 140 | def distr_params(self, x): 141 | x = self.parameter_net(x) 142 | mean, ls = x.chunk(2, dim=1) 143 | ls = ls + self.log_scale_bias 144 | ls = ls.clamp(min=-7.) 145 | mean = mean + 0.5 # initialize to mid interval 146 | params = { 147 | 'mean': mean, 148 | 'logscale': ls, 149 | } 150 | return params 151 | 152 | @staticmethod 153 | def mean(params): 154 | return params['mean'] 155 | 156 | @staticmethod 157 | def mode(params): 158 | return params['mean'] 159 | 160 | @staticmethod 161 | def sample(params): 162 | # We're not quantizing 8bit, but it doesn't matter 163 | sample = logistic_rsample((params['mean'], params['logscale'])) 164 | sample = sample.clamp(min=0., max=1.) 165 | return sample 166 | 167 | def log_likelihood(self, x, params): 168 | # Input data x should be inside (not at the edge of) n_bins equally-sized 169 | # bins between 0 and 1. E.g. if n_bins=256 the 257 bin edges are: 170 | # 0, 1/256, ..., 255/256, 1. 171 | 172 | x = x * (255 / 256) + 1 / 512 173 | 174 | logprob = log_discretized_logistic(x, 175 | params['mean'], 176 | params['logscale'], 177 | n_bins=self.n_bins, 178 | reduce='none', 179 | double=self.double_precision) 180 | return logprob 181 | 182 | 183 | class DiscretizedLogisticMixLikelihood(LikelihoodModule): 184 | """ 185 | Sampling and loss computation are based on the original tf code. 186 | 187 | Assume input data to be originally uint8 (0, ..., 255) and then rescaled 188 | by 1/255: discrete values in {0, 1/255, ..., 255/255}. 189 | When using the original discretized logistic mixture logprob implementation, 190 | this data should be rescaled to be in [-1, 1]. 191 | 192 | Mean and mode are not implemented for now. 193 | 194 | For now, the number of color channels is fixed to 3 and n_bins to 256. 195 | """ 196 | 197 | def __init__(self, ch_in, n_components=10): 198 | super().__init__() 199 | self.parameter_net = nn.Conv2d(ch_in, 200 | 10 * n_components, 201 | kernel_size=3, 202 | padding=1) 203 | 204 | def distr_params(self, x): 205 | x = self.parameter_net(x) 206 | params = { 207 | 'mean': None, # TODO 208 | 'all_params': x 209 | } 210 | return params 211 | 212 | @staticmethod 213 | def mean(params): 214 | return params['mean'] 215 | 216 | @staticmethod 217 | def mode(params): 218 | return params['mean'] 219 | 220 | @staticmethod 221 | def sample(params): 222 | sample = sample_from_discretized_mix_logistic(params['all_params']) 223 | sample = (sample + 1) / 2 224 | sample = sample.clamp(min=0., max=1.) 225 | return sample 226 | 227 | def log_likelihood(self, x, params): 228 | x = x * 2 - 1 229 | logprob = -discretized_mix_logistic_loss(x, params['all_params']) 230 | return logprob 231 | 232 | 233 | def log_discretized_logistic(x, 234 | mean, 235 | log_scale, 236 | n_bins=256, 237 | reduce='mean', 238 | double=False): 239 | """ 240 | Log of the probability mass of the values x under the logistic distribution 241 | with parameters mean and scale. The sum is taken over all dimensions except 242 | for the first one (assumed to be batch). Reduction is applied at the end. 243 | 244 | Assume input data to be inside (not at the edge of) n_bins equally-sized 245 | bins between 0 and 1. E.g. if n_bins=256 the 257 bin edges are: 246 | 0, 1/256, ..., 255/256, 1.
247 | If values are at the left edge it's also ok, but let's be on the safe side 248 | 249 | Variance of logistic distribution is 250 | var = scale^2 * pi^2 / 3 251 | 252 | :param x: tensor with shape (batch, channels, dim1, dim2) 253 | :param mean: tensor with mean of distribution, shape 254 | (batch, channels, dim1, dim2) 255 | :param log_scale: tensor with log scale of distribution, shape has to be either 256 | scalar or broadcastable 257 | :param n_bins: bin size (default: 256) 258 | :param reduce: reduction over batch: 'mean' | 'sum' | 'none' 259 | :param double: whether double precision should be used for computations 260 | :return: 261 | """ 262 | log_scale = _input_check(x, mean, log_scale, reduce) 263 | if double: 264 | log_scale = log_scale.double() 265 | x = x.double() 266 | mean = mean.double() 267 | eps = 1e-14 268 | else: 269 | eps = 1e-7 270 | # scale = np.sqrt(3) * torch.exp(logvar / 2) / np.pi 271 | scale = log_scale.exp() 272 | 273 | # Set values to the left of each bin 274 | x = torch.floor(x * n_bins) / n_bins 275 | 276 | cdf_plus = torch.ones_like(x) 277 | idx = x < (n_bins - 1) / n_bins 278 | cdf_plus[idx] = torch.sigmoid( 279 | (x[idx] + 1 / n_bins - mean[idx]) / scale[idx]) 280 | cdf_minus = torch.zeros_like(x) 281 | idx = x >= 1 / n_bins 282 | cdf_minus[idx] = torch.sigmoid((x[idx] - mean[idx]) / scale[idx]) 283 | log_prob = torch.log(cdf_plus - cdf_minus + eps) 284 | log_prob = log_prob.sum((1, 2, 3)) 285 | log_prob = _reduce(log_prob, reduce) 286 | if double: 287 | log_prob = log_prob.float() 288 | return log_prob 289 | 290 | 291 | def discretized_mix_logistic_loss(x, l): 292 | """ 293 | log-likelihood for mixture of discretized logistics, assumes the data 294 | has been rescaled to [-1,1] interval 295 | 296 | Code taken from pytorch adaptation of original PixelCNN++ tf implementation 297 | https://github.com/pclucas14/pixel-cnn-pp 298 | """ 299 | 300 | # channels last 301 | x = x.permute(0, 2, 3, 1) 302 | l = l.permute(0, 2, 3, 1) 303 | 304 | # true image (i.e. labels) to regress to, e.g. (B,32,32,3) 305 | xs = [int(y) for y in x.size()] 306 | # predicted distribution, e.g. (B,32,32,100) 307 | ls = [int(y) for y in l.size()] 308 | 309 | # here and below: unpacking the params of the mixture of logistics 310 | nr_mix = int(ls[-1] / 10) 311 | logit_probs = l[:, :, :, :nr_mix] 312 | l = l[:, :, :, nr_mix:].contiguous().view( 313 | xs + [nr_mix * 3]) # 3 for mean, scale, coef 314 | means = l[:, :, :, :, :nr_mix] 315 | # log_scales = torch.max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.) 316 | log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) 317 | 318 | coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) 319 | # here and below: getting the means and adjusting them based on preceding 320 | # sub-pixels 321 | x = x.contiguous() 322 | x = x.unsqueeze(-1) + nn.Parameter(torch.zeros(xs + [nr_mix]).to(x.device), 323 | requires_grad=False) 324 | m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :]).view( 325 | xs[0], xs[1], xs[2], 1, nr_mix) 326 | 327 | m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + 328 | coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view( 329 | xs[0], xs[1], xs[2], 1, nr_mix) 330 | 331 | means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3) 332 | centered_x = x - means 333 | inv_stdv = torch.exp(-log_scales) 334 | plus_in = inv_stdv * (centered_x + 1. / 255.) 335 | cdf_plus = torch.sigmoid(plus_in) 336 | min_in = inv_stdv * (centered_x - 1. / 255.) 
337 |     cdf_min = torch.sigmoid(min_in)
338 |     # log probability for edge case of 0 (before scaling)
339 |     log_cdf_plus = plus_in - F.softplus(plus_in)
340 |     # log probability for edge case of 255 (before scaling)
341 |     log_one_minus_cdf_min = -F.softplus(min_in)
342 |     cdf_delta = cdf_plus - cdf_min  # probability for all other cases
343 |     mid_in = inv_stdv * centered_x
344 |     # log probability in the center of the bin, to be used in extreme cases
345 |     # (not actually used in our code)
346 |     log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
347 | 
348 |     # now select the right output: left edge case, right edge case, normal
349 |     # case, extremely low prob case (doesn't actually happen for us)
350 | 
351 |     # this is what we are really doing, but using the robust version below
352 |     # for extreme cases in other applications and to avoid the NaN issue with tf.select()
353 |     # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999,
354 |     # log_one_minus_cdf_min, tf.log(cdf_delta)))
355 | 
356 |     # robust version, which still works if probabilities are below 1e-5 (this
357 |     # never happens in our code)
358 |     # tensorflow backpropagates through tf.select() by multiplying with zero
359 |     # instead of selecting: this requires us to use some ugly tricks to avoid
360 |     # potential NaNs
361 |     # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as
362 |     # output, it's purely there to get around the tf.select() gradient issue
363 |     # if the probability on a sub-pixel is below 1e-5, we use an approximation
364 |     # based on the assumption that the log-density is constant in the bin of
365 |     # the observed sub-pixel value
366 | 
367 |     inner_inner_cond = (cdf_delta > 1e-5).float()
368 |     inner_inner_out = inner_inner_cond * torch.log(
369 |         torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (
370 |             log_pdf_mid - np.log(127.5))
371 |     inner_cond = (x > 0.999).float()
372 |     inner_out = inner_cond * log_one_minus_cdf_min + (
373 |         1. - inner_cond) * inner_inner_out
374 |     cond = (x < -0.999).float()
375 |     log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
376 |     log_probs = torch.sum(log_probs, dim=3) + torch.log_softmax(logit_probs,
377 |                                                                 dim=-1)
378 |     log_probs = torch.logsumexp(log_probs, dim=-1)
379 | 
380 |     # return -torch.sum(log_probs)
381 |     loss_sep = -log_probs.sum((1, 2))  # keep batch dimension
382 |     return loss_sep
383 | 
384 | 
385 | def log_bernoulli(x, mean, reduce='mean'):
386 |     log_prob = -F.binary_cross_entropy(mean, x, reduction='none')
387 |     log_prob = log_prob.sum((1, 2, 3))
388 |     return _reduce(log_prob, reduce)
389 | 
390 | 
391 | def log_normal(x, mean, logvar, reduce='mean'):
392 |     """
393 |     Log of the probability density of the values x under the Normal
394 |     distribution with parameters mean and logvar. The sum is taken over all
395 |     dimensions except for the first one (assumed to be batch). Reduction
396 |     is applied at the end.
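
    Concretely, the elementwise log-density is
        -0.5 * (log(2*pi) + logvar + (x - mean)^2 / exp(logvar)),
    which is then summed over all dimensions except the batch one.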
397 | :param x: tensor of points, with shape (batch, channels, dim1, dim2) 398 | :param mean: tensor with mean of distribution, shape 399 | (batch, channels, dim1, dim2) 400 | :param logvar: tensor with log-variance of distribution, shape has to be 401 | either scalar or broadcastable 402 | :param reduce: reduction over batch: 'mean' | 'sum' | 'none' 403 | :return: 404 | """ 405 | 406 | logvar = _input_check(x, mean, logvar, reduce) 407 | var = torch.exp(logvar) 408 | log_prob = -0.5 * (( 409 | (x - mean)**2) / var + logvar + torch.tensor(2 * math.pi).log()) 410 | log_prob = log_prob.sum((1, 2, 3)) 411 | return _reduce(log_prob, reduce) 412 | 413 | 414 | def _reduce(x, reduce): 415 | if reduce == 'mean': 416 | x = x.mean() 417 | elif reduce == 'sum': 418 | x = x.sum() 419 | return x 420 | 421 | 422 | def _input_check(x, mean, scale_param, reduce): 423 | assert x.dim() == 4 424 | assert x.size() == mean.size() 425 | if scale_param.numel() == 1: 426 | scale_param = scale_param.view(1, 1, 1, 1) 427 | if reduce not in ['mean', 'sum', 'none']: 428 | msg = "unrecognized reduction method '{}'".format(reduce) 429 | raise RuntimeError(msg) 430 | return scale_param 431 | 432 | 433 | if __name__ == '__main__': 434 | 435 | import seaborn as sns 436 | sns.set() 437 | 438 | # *** Test discretized logistic likelihood and plot examples 439 | 440 | # Fix predicted distribution, change true data from 0 to 1: 441 | # show log probability of given distribution on the range [0, 1] 442 | t = torch.arange(0., 1., 1 / 10000).view(-1, 1, 1, 1) 443 | mean_ = torch.zeros_like(t) + 0.3 444 | logscales = np.arange(-7., 0., 1.) 445 | plt.figure(figsize=(15, 8)) 446 | for logscale_ in logscales: 447 | logscale = torch.tensor(logscale_).expand_as(t) 448 | log_prob = log_discretized_logistic(t, 449 | mean_, 450 | logscale, 451 | n_bins=256, 452 | reduce='none', 453 | double=True) 454 | plt.plot(t.flatten().numpy(), 455 | log_prob.numpy(), 456 | label='logscale={}'.format(logscale_)) 457 | plt.xlabel('data (x)') 458 | plt.ylabel('logprob') 459 | plt.title('log DiscrLogistic(x | 0.3, scale)') 460 | plt.legend() 461 | plt.show() 462 | 463 | # Fix true data, change distribution: 464 | # show log probability of fixed data under different distributions 465 | logscales = np.arange(-7., 0., 1.) 466 | mean_ = torch.arange(0., 1., 1 / 10000).view(-1, 1, 1, 1) 467 | t = torch.tensor(0.3).expand_as(mean_) 468 | plt.figure(figsize=(15, 8)) 469 | for logscale_ in logscales: 470 | logscale = torch.tensor(logscale_).expand_as(mean_) 471 | log_prob = log_discretized_logistic(t, 472 | mean_, 473 | logscale, 474 | n_bins=256, 475 | reduce='none', 476 | double=True) 477 | plt.plot(mean_.flatten().numpy(), 478 | log_prob.numpy(), 479 | label='logscale={}'.format(logscale_)) 480 | plt.xlabel('mean of logistic') 481 | plt.ylabel('logprob') 482 | plt.title('log DiscrLogistic(0.3 | mean, scale)') 483 | plt.legend() 484 | plt.show() 485 | -------------------------------------------------------------------------------- /lib/nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ResidualBlock(nn.Module): 6 | """ 7 | Residual block with 2 convolutional layers. 8 | Input, intermediate, and output channels are the same. Padding is always 9 | 'same'. The 2 convolutional layers have the same groups. No stride allowed, 10 | and kernel sizes have to be odd. 
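    For example (illustrative), ResidualBlock(channels=64, nonlin=nn.ELU,
    block_type='bacdbac', dropout=0.1, gated=True) builds
    batchnorm-ELU-conv-dropout-batchnorm-ELU-conv followed by a gate layer,
    all wrapped in a residual connection.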
11 | 
12 |     The result is:
13 |         out = gate(f(x)) + x
14 |     where the `gated` argument controls the presence of the gating mechanism,
15 |     and f(x) has different structures depending on the argument block_type.
16 |     block_type is a string specifying the structure of the block, where:
17 |         a = activation
18 |         b = batch norm
19 |         c = conv layer
20 |         d = dropout.
21 |     For example, bacdbacd has 2x (batchnorm, activation, conv, dropout).
22 |     """
23 | 
24 |     default_kernel_size = (3, 3)
25 | 
26 |     def __init__(self,
27 |                  channels,
28 |                  nonlin,
29 |                  kernel=None,
30 |                  groups=1,
31 |                  batchnorm=True,
32 |                  block_type=None,
33 |                  dropout=None,
34 |                  gated=None):
35 |         super().__init__()
36 |         if kernel is None:
37 |             kernel = self.default_kernel_size
38 |         elif isinstance(kernel, int):
39 |             kernel = (kernel, kernel)
40 |         elif len(kernel) != 2:
41 |             raise ValueError(
42 |                 "kernel has to be None, int, or an iterable of length 2")
43 |         assert all([k % 2 == 1 for k in kernel]), "kernel sizes have to be odd"
44 |         kernel = list(kernel)
45 |         pad = [k // 2 for k in kernel]
46 |         self.gated = gated
47 | 
48 |         modules = []
49 | 
50 |         if block_type == 'cabdcabd':
51 |             for i in range(2):
52 |                 conv = nn.Conv2d(channels,
53 |                                  channels,
54 |                                  kernel[i],
55 |                                  padding=pad[i],
56 |                                  groups=groups)
57 |                 modules.append(conv)
58 |                 modules.append(nonlin())
59 |                 if batchnorm:
60 |                     modules.append(nn.BatchNorm2d(channels))
61 |                 if dropout is not None:
62 |                     modules.append(nn.Dropout2d(dropout))
63 | 
64 |         elif block_type == 'bacdbac':
65 |             for i in range(2):
66 |                 if batchnorm:
67 |                     modules.append(nn.BatchNorm2d(channels))
68 |                 modules.append(nonlin())
69 |                 conv = nn.Conv2d(channels,
70 |                                  channels,
71 |                                  kernel[i],
72 |                                  padding=pad[i],
73 |                                  groups=groups)
74 |                 modules.append(conv)
75 |                 if dropout is not None and i == 0:
76 |                     modules.append(nn.Dropout2d(dropout))
77 | 
78 |         elif block_type == 'bacdbacd':
79 |             for i in range(2):
80 |                 if batchnorm:
81 |                     modules.append(nn.BatchNorm2d(channels))
82 |                 modules.append(nonlin())
83 |                 conv = nn.Conv2d(channels,
84 |                                  channels,
85 |                                  kernel[i],
86 |                                  padding=pad[i],
87 |                                  groups=groups)
88 |                 modules.append(conv)
89 |                 if dropout is not None: modules.append(nn.Dropout2d(dropout))
90 | 
91 |         else:
92 |             raise ValueError("unrecognized block type '{}'".format(block_type))
93 | 
94 |         if gated:
95 |             modules.append(GateLayer2d(channels, 1, nonlin))
96 |         self.block = nn.Sequential(*modules)
97 | 
98 |     def forward(self, x):
99 |         return self.block(x) + x
100 | 
101 | 
102 | class ResidualGatedBlock(ResidualBlock):
103 | 
104 |     def __init__(self, *args, **kwargs):
105 |         super().__init__(*args, **kwargs, gated=True)
106 | 
107 | 
108 | class GateLayer2d(nn.Module):
109 |     """
110 |     Double the number of channels through a convolutional layer, then use
111 |     half the channels as gate for the other half.
112 |     """
113 | 
114 |     def __init__(self, channels, kernel_size, nonlin=nn.LeakyReLU):
115 |         super().__init__()
116 |         assert kernel_size % 2 == 1
117 |         pad = kernel_size // 2
118 |         self.conv = nn.Conv2d(channels, 2 * channels, kernel_size, padding=pad)
119 |         self.nonlin = nonlin()
120 | 
121 |     def forward(self, x):
122 |         x = self.conv(x)
123 |         x, gate = torch.chunk(x, 2, dim=1)
124 |         x = self.nonlin(x)  # TODO remove this?
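        # The second half of the channels, squashed to (0, 1) below, acts as a
        # multiplicative gate for the (activated) first half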
125 |         gate = torch.sigmoid(gate)
126 |         return x * gate
127 | 
--------------------------------------------------------------------------------
/lib/stochastic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.distributions import kl_divergence
4 | from torch.distributions.normal import Normal
5 | 
6 | 
7 | class NormalStochasticBlock2d(nn.Module):
8 |     """
9 |     Transform input parameters to q(z) with a convolution, optionally do the
10 |     same for p(z), then sample z ~ q(z) and return conv(z).
11 | 
12 |     If q's parameters are not given, do the same but sample from p(z).
13 |     """
14 | 
15 |     def __init__(self, c_in, c_vars, c_out, kernel=3, transform_p_params=True):
16 |         super().__init__()
17 |         assert kernel % 2 == 1
18 |         pad = kernel // 2
19 |         self.transform_p_params = transform_p_params
20 |         self.c_in = c_in
21 |         self.c_out = c_out
22 |         self.c_vars = c_vars
23 | 
24 |         if transform_p_params:
25 |             self.conv_in_p = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
26 |         self.conv_in_q = nn.Conv2d(c_in, 2 * c_vars, kernel, padding=pad)
27 |         self.conv_out = nn.Conv2d(c_vars, c_out, kernel, padding=pad)
28 | 
29 |     def forward(self,
30 |                 p_params,
31 |                 q_params=None,
32 |                 forced_latent=None,
33 |                 use_mode=False,
34 |                 force_constant_output=False,
35 |                 analytical_kl=False):
36 | 
37 |         assert (forced_latent is None) or (not use_mode)
38 | 
39 |         if self.transform_p_params:
40 |             p_params = self.conv_in_p(p_params)
41 |         else:
42 |             assert p_params.size(1) == 2 * self.c_vars
43 | 
44 |         # Define p(z)
45 |         p_mu, p_lv = p_params.chunk(2, dim=1)
46 |         p = Normal(p_mu, (p_lv / 2).exp())
47 | 
48 |         if q_params is not None:
49 |             # Define q(z)
50 |             q_params = self.conv_in_q(q_params)
51 |             q_mu, q_lv = q_params.chunk(2, dim=1)
52 |             q = Normal(q_mu, (q_lv / 2).exp())
53 | 
54 |             # Sample from q(z)
55 |             sampling_distrib = q
56 |         else:
57 |             # Sample from p(z)
58 |             sampling_distrib = p
59 | 
60 |         # Generate latent variable (typically by sampling)
61 |         if forced_latent is None:
62 |             if use_mode:
63 |                 z = sampling_distrib.mean  # mean = mode for a Gaussian
64 |             else:
65 |                 z = sampling_distrib.rsample()
66 |         else:
67 |             z = forced_latent
68 | 
69 |         # Copy one sample (and distrib parameters) over the whole batch.
70 |         # This is used when doing experiments from the prior - q is not used.
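        # Expanding element 0 over the batch makes every image generated
        # downstream of this layer share the same latent sample and prior
        # parameters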
71 | if force_constant_output: 72 | z = z[0:1].expand_as(z).clone() 73 | p_params = p_params[0:1].expand_as(p_params).clone() 74 | 75 | # Output of stochastic layer 76 | out = self.conv_out(z) 77 | 78 | # Compute log p(z) 79 | logprob_p = p.log_prob(z).sum((1, 2, 3)) 80 | 81 | if q_params is not None: 82 | 83 | # Compute log q(z) 84 | logprob_q = q.log_prob(z).sum((1, 2, 3)) 85 | 86 | # Compute KL (analytical or MC estimate) 87 | kl_analytical = kl_divergence(q, p) 88 | if analytical_kl: 89 | kl_elementwise = kl_analytical 90 | else: 91 | kl_elementwise = kl_normal_mc(z, p_params, q_params) 92 | kl_samplewise = kl_elementwise.sum((1, 2, 3)) 93 | 94 | # Compute spatial KL analytically (but conditioned on samples from 95 | # previous layers) 96 | kl_spatial_analytical = kl_analytical.sum(1) 97 | 98 | else: 99 | kl_elementwise = kl_samplewise = kl_spatial_analytical = None 100 | logprob_q = None 101 | 102 | data = { 103 | 'z': z, # sampled variable at this layer (batch, ch, h, w) 104 | 'p_params': p_params, # (b, ch, h, w) where b is 1 or batch size 105 | 'q_params': q_params, # (batch, ch, h, w) 106 | 'logprob_p': logprob_p, # (batch, ) 107 | 'logprob_q': logprob_q, # (batch, ) 108 | 'kl_elementwise': kl_elementwise, # (batch, ch, h, w) 109 | 'kl_samplewise': kl_samplewise, # (batch, ) 110 | 'kl_spatial': kl_spatial_analytical, # (batch, h, w) 111 | } 112 | return out, data 113 | 114 | 115 | def logistic_rsample(mu_ls): 116 | """ 117 | Returns a sample from Logistic with specified mean and log scale. 118 | :param mu_ls: a tensor containing mean and log scale along dim=1, 119 | or a tuple (mean, log scale) 120 | :return: a reparameterized sample with the same size as the input 121 | mean and log scale 122 | """ 123 | 124 | # Get parameters 125 | try: 126 | mu, log_scale = torch.chunk(mu_ls, 2, dim=1) 127 | except TypeError: 128 | mu, log_scale = mu_ls 129 | scale = log_scale.exp() 130 | 131 | # Get uniform sample in open interval (0, 1) 132 | u = torch.zeros_like(mu) 133 | u.uniform_(1e-7, 1 - 1e-7) 134 | 135 | # Transform into logistic sample 136 | sample = mu + scale * (torch.log(u) - torch.log(1 - u)) 137 | 138 | return sample 139 | 140 | 141 | def sample_from_discretized_mix_logistic(l): 142 | """ 143 | Code taken from pytorch adaptation of original PixelCNN++ tf implementation 144 | https://github.com/pclucas14/pixel-cnn-pp 145 | """ 146 | 147 | def to_one_hot(tensor, n): 148 | one_hot = torch.zeros(tensor.size() + (n,)) 149 | one_hot = one_hot.to(tensor.device) 150 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), 1.) 151 | return one_hot 152 | 153 | # Pytorch ordering 154 | l = l.permute(0, 2, 3, 1) 155 | ls = [int(y) for y in l.size()] 156 | xs = ls[:-1] + [3] 157 | 158 | # here and below: unpacking the params of the mixture of logistics 159 | nr_mix = int(ls[-1] / 10) 160 | 161 | # unpack parameters 162 | logit_probs = l[:, :, :, :nr_mix] 163 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3]) 164 | # sample mixture indicator from softmax 165 | temp = torch.FloatTensor(logit_probs.size()) 166 | if l.is_cuda: 167 | temp = temp.cuda() 168 | temp.uniform_(1e-5, 1. - 1e-5) 169 | temp = logit_probs.data - torch.log(-torch.log(temp)) 170 | _, argmax = temp.max(dim=3) 171 | 172 | one_hot = to_one_hot(argmax, nr_mix) 173 | sel = one_hot.view(xs[:-1] + [1, nr_mix]) 174 | # select logistic parameters 175 | means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) 176 | log_scales = torch.clamp(torch.sum(l[:, :, :, :, nr_mix:2 * nr_mix] * sel, 177 | dim=4), 178 | min=-7.) 
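    # means, log_scales, and coeffs each keep, per pixel, only the parameters
    # of the mixture component selected by the one-hot `sel`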
179 | coeffs = torch.sum(torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, 180 | dim=4) 181 | # sample from logistic & clip to interval 182 | # we don't actually round to the nearest 8bit value when sampling 183 | u = torch.FloatTensor(means.size()) 184 | if l.is_cuda: 185 | u = u.cuda() 186 | u.uniform_(1e-5, 1. - 1e-5) 187 | u = nn.Parameter(u) 188 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 189 | x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.) 190 | x1 = torch.clamp(torch.clamp(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, 191 | min=-1.), 192 | max=1.) 193 | x2 = torch.clamp(torch.clamp(x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + 194 | coeffs[:, :, :, 2] * x1, 195 | min=-1.), 196 | max=1.) 197 | 198 | out = torch.cat([ 199 | x0.view(xs[:-1] + [1]), 200 | x1.view(xs[:-1] + [1]), 201 | x2.view(xs[:-1] + [1]) 202 | ], 203 | dim=3) 204 | # put back in Pytorch ordering 205 | out = out.permute(0, 3, 1, 2) 206 | return out 207 | 208 | 209 | def kl_normal_mc(z, p_mulv, q_mulv): 210 | """ 211 | One-sample estimation of element-wise KL between two diagonal 212 | multivariate normal distributions. Any number of dimensions, 213 | broadcasting supported (be careful). 214 | 215 | :param z: 216 | :param p_mulv: 217 | :param q_mulv: 218 | :return: 219 | """ 220 | p_mu, p_lv = torch.chunk(p_mulv, 2, dim=1) 221 | q_mu, q_lv = torch.chunk(q_mulv, 2, dim=1) 222 | p_std = (p_lv / 2).exp() 223 | q_std = (q_lv / 2).exp() 224 | p_distrib = Normal(p_mu, p_std) 225 | q_distrib = Normal(q_mu, q_std) 226 | return q_distrib.log_prob(z) - p_distrib.log_prob(z) 227 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from boilr import Trainer 2 | 3 | from experiment import LVAEExperiment 4 | 5 | 6 | def main(): 7 | experiment = LVAEExperiment() 8 | trainer = Trainer(experiment) 9 | trainer.run() 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/ladder-vae-pytorch/a27d45ce1c9b7b1cba49813f86c7fe99529179b5/models/__init__.py -------------------------------------------------------------------------------- /models/lvae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from boilr.models import BaseGenerativeModel 4 | from boilr.nn import crop_img_tensor, pad_img_tensor, Interpolate, free_bits_kl 5 | from torch import nn 6 | 7 | from lib.likelihoods import (BernoulliLikelihood, GaussianLikelihood, 8 | DiscretizedLogisticLikelihood, 9 | DiscretizedLogisticMixLikelihood) 10 | from .lvae_layers import (TopDownLayer, BottomUpLayer, 11 | TopDownDeterministicResBlock, 12 | BottomUpDeterministicResBlock) 13 | 14 | 15 | class LadderVAE(BaseGenerativeModel): 16 | 17 | def __init__(self, 18 | color_ch, 19 | z_dims, 20 | blocks_per_layer=2, 21 | downsample=None, 22 | nonlin='elu', 23 | merge_type=None, 24 | batchnorm=True, 25 | stochastic_skip=False, 26 | n_filters=32, 27 | dropout=None, 28 | free_bits=0.0, 29 | learn_top_prior=False, 30 | img_shape=None, 31 | likelihood_form=None, 32 | res_block_type=None, 33 | gated=False, 34 | no_initial_downscaling=False, 35 | analytical_kl=False): 36 | super().__init__() 37 | self.color_ch = color_ch 38 | self.z_dims = z_dims 
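        # z_dims[i] is the number of channels of the latent variable at
        # layer i, with i = 0 the bottom (highest-resolution) layer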
39 | self.blocks_per_layer = blocks_per_layer 40 | self.downsample = downsample 41 | self.n_layers = len(self.z_dims) 42 | self.stochastic_skip = stochastic_skip 43 | self.n_filters = n_filters 44 | self.dropout = dropout 45 | self.free_bits = free_bits 46 | self.learn_top_prior = learn_top_prior 47 | self.img_shape = tuple(img_shape) 48 | self.res_block_type = res_block_type 49 | self.gated = gated 50 | 51 | # Default: no downsampling (except for initial bottom-up block) 52 | if self.downsample is None: 53 | self.downsample = [0] * self.n_layers 54 | 55 | # Downsample by a factor of 2 at each downsampling operation 56 | self.overall_downscale_factor = np.power(2, sum(self.downsample)) 57 | if not no_initial_downscaling: # by default do another downscaling 58 | self.overall_downscale_factor *= 2 59 | 60 | assert max(self.downsample) <= self.blocks_per_layer 61 | assert len(self.downsample) == self.n_layers 62 | 63 | # Get class of nonlinear activation from string description 64 | nonlin = { 65 | 'relu': nn.ReLU, 66 | 'leakyrelu': nn.LeakyReLU, 67 | 'elu': nn.ELU, 68 | 'selu': nn.SELU, 69 | }[nonlin] 70 | 71 | # First bottom-up layer: change num channels + downsample by factor 2 72 | # unless we want to prevent this 73 | stride = 1 if no_initial_downscaling else 2 74 | self.first_bottom_up = nn.Sequential( 75 | nn.Conv2d(color_ch, n_filters, 5, padding=2, stride=stride), 76 | nonlin(), 77 | BottomUpDeterministicResBlock( 78 | c_in=n_filters, 79 | c_out=n_filters, 80 | nonlin=nonlin, 81 | batchnorm=batchnorm, 82 | dropout=dropout, 83 | res_block_type=res_block_type, 84 | )) 85 | 86 | # Init lists of layers 87 | self.top_down_layers = nn.ModuleList([]) 88 | self.bottom_up_layers = nn.ModuleList([]) 89 | 90 | for i in range(self.n_layers): 91 | 92 | # Whether this is the top layer 93 | is_top = i == self.n_layers - 1 94 | 95 | # Add bottom-up deterministic layer at level i. 96 | # It's a sequence of residual blocks (BottomUpDeterministicResBlock) 97 | # possibly with downsampling between them. 98 | self.bottom_up_layers.append( 99 | BottomUpLayer( 100 | n_res_blocks=self.blocks_per_layer, 101 | n_filters=n_filters, 102 | downsampling_steps=self.downsample[i], 103 | nonlin=nonlin, 104 | batchnorm=batchnorm, 105 | dropout=dropout, 106 | res_block_type=res_block_type, 107 | gated=gated, 108 | )) 109 | 110 | # Add top-down stochastic layer at level i. 111 | # The architecture when doing inference is roughly as follows: 112 | # p_params = output of top-down layer above 113 | # bu = inferred bottom-up value at this layer 114 | # q_params = merge(bu, p_params) 115 | # z = stochastic_layer(q_params) 116 | # possibly get skip connection from previous top-down layer 117 | # top-down deterministic ResNet 118 | # 119 | # When doing generation only, the value bu is not available, the 120 | # merge layer is not used, and z is sampled directly from p_params. 
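            # (This scheme is implemented by TopDownLayer in
            # models/lvae_layers.py.)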
121 | # 122 | self.top_down_layers.append( 123 | TopDownLayer( 124 | z_dim=z_dims[i], 125 | n_res_blocks=blocks_per_layer, 126 | n_filters=n_filters, 127 | is_top_layer=is_top, 128 | downsampling_steps=downsample[i], 129 | nonlin=nonlin, 130 | merge_type=merge_type, 131 | batchnorm=batchnorm, 132 | dropout=dropout, 133 | stochastic_skip=stochastic_skip, 134 | learn_top_prior=learn_top_prior, 135 | top_prior_param_shape=self.get_top_prior_param_shape(), 136 | res_block_type=res_block_type, 137 | gated=gated, 138 | analytical_kl=analytical_kl, 139 | )) 140 | 141 | # Final top-down layer 142 | modules = list() 143 | if not no_initial_downscaling: 144 | modules.append(Interpolate(scale=2)) 145 | for i in range(blocks_per_layer): 146 | modules.append( 147 | TopDownDeterministicResBlock( 148 | c_in=n_filters, 149 | c_out=n_filters, 150 | nonlin=nonlin, 151 | batchnorm=batchnorm, 152 | dropout=dropout, 153 | res_block_type=res_block_type, 154 | gated=gated, 155 | )) 156 | self.final_top_down = nn.Sequential(*modules) 157 | 158 | # Define likelihood 159 | if likelihood_form == 'bernoulli': 160 | self.likelihood = BernoulliLikelihood(n_filters, color_ch) 161 | elif likelihood_form == 'gaussian': 162 | self.likelihood = GaussianLikelihood(n_filters, color_ch) 163 | elif likelihood_form == 'discr_log': 164 | self.likelihood = DiscretizedLogisticLikelihood( 165 | n_filters, color_ch, 256) 166 | elif likelihood_form == 'discr_log_mix': 167 | self.likelihood = DiscretizedLogisticMixLikelihood(n_filters) 168 | else: 169 | msg = "Unrecognized likelihood '{}'".format(likelihood_form) 170 | raise RuntimeError(msg) 171 | 172 | def forward(self, x): 173 | img_size = x.size()[2:] 174 | 175 | # Pad input to make everything easier with conv strides 176 | x_pad = self.pad_input(x) 177 | 178 | # Bottom-up inference: return list of length n_layers (bottom to top) 179 | bu_values = self.bottomup_pass(x_pad) 180 | 181 | # Top-down inference/generation 182 | out, td_data = self.topdown_pass(bu_values) 183 | 184 | # Restore original image size 185 | out = crop_img_tensor(out, img_size) 186 | 187 | # Log likelihood and other info (per data point) 188 | ll, likelihood_info = self.likelihood(out, x) 189 | 190 | # kl[i] for each i has length batch_size 191 | # resulting kl shape: (batch_size, layers) 192 | kl = torch.cat([kl_layer.unsqueeze(1) for kl_layer in td_data['kl']], 193 | dim=1) 194 | 195 | kl_sep = kl.sum(1) 196 | kl_avg_layerwise = kl.mean(0) 197 | kl_loss = free_bits_kl(kl, self.free_bits).sum() # sum over layers 198 | kl = kl_sep.mean() 199 | 200 | output = { 201 | 'll': ll, 202 | 'z': td_data['z'], 203 | 'kl': kl, 204 | 'kl_sep': kl_sep, 205 | 'kl_avg_layerwise': kl_avg_layerwise, 206 | 'kl_spatial': td_data['kl_spatial'], 207 | 'kl_loss': kl_loss, 208 | 'logp': td_data['logprob_p'], 209 | 'out_mean': likelihood_info['mean'], 210 | 'out_mode': likelihood_info['mode'], 211 | 'out_sample': likelihood_info['sample'], 212 | 'likelihood_params': likelihood_info['params'] 213 | } 214 | return output 215 | 216 | def bottomup_pass(self, x): 217 | # Bottom-up initial layer 218 | x = self.first_bottom_up(x) 219 | 220 | # Loop from bottom to top layer, store all deterministic nodes we 221 | # need in the top-down pass 222 | bu_values = [] 223 | for i in range(self.n_layers): 224 | x = self.bottom_up_layers[i](x) 225 | bu_values.append(x) 226 | 227 | return bu_values 228 | 229 | def topdown_pass(self, 230 | bu_values=None, 231 | n_img_prior=None, 232 | mode_layers=None, 233 | constant_layers=None, 234 | 
forced_latent=None):
235 | 
236 |         # Default: no layer is sampled from the distribution's mode
237 |         if mode_layers is None:
238 |             mode_layers = []
239 |         if constant_layers is None:
240 |             constant_layers = []
241 |         prior_experiment = len(mode_layers) > 0 or len(constant_layers) > 0
242 | 
243 |         # If the bottom-up inference values are not given, don't do
244 |         # inference, sample from prior instead
245 |         inference_mode = bu_values is not None
246 | 
247 |         # Check consistency of arguments
248 |         if inference_mode != (n_img_prior is None):
249 |             msg = ("Number of images for top-down generation has to be given "
250 |                    "if and only if we're not doing inference")
251 |             raise RuntimeError(msg)
252 |         if inference_mode and prior_experiment:
253 |             msg = ("Prior experiments (e.g. sampling from mode) are not"
254 |                    " compatible with inference mode")
255 |             raise RuntimeError(msg)
256 | 
257 |         # Sampled latent variables at each layer
258 |         z = [None] * self.n_layers
259 | 
260 |         # KL divergence of each layer
261 |         kl = [None] * self.n_layers
262 | 
263 |         # Spatial map of KL divergence for each layer
264 |         kl_spatial = [None] * self.n_layers
265 | 
266 |         if forced_latent is None:
267 |             forced_latent = [None] * self.n_layers
268 | 
269 |         # log p(z) where z is the sample in the topdown pass
270 |         logprob_p = 0.
271 | 
272 |         # Top-down inference/generation loop
273 |         out = out_pre_residual = None
274 |         for i in reversed(range(self.n_layers)):
275 | 
276 |             # If available, get deterministic node from bottom-up inference
277 |             try:
278 |                 bu_value = bu_values[i]
279 |             except TypeError:
280 |                 bu_value = None
281 | 
282 |             # Whether the current layer should be sampled from the mode
283 |             use_mode = i in mode_layers
284 |             constant_out = i in constant_layers
285 | 
286 |             # Input for skip connection
287 |             skip_input = out  # TODO or out_pre_residual? or both?
288 | 
289 |             # Full top-down layer, including sampling and deterministic part
290 |             out, out_pre_residual, aux = self.top_down_layers[i](
291 |                 out,
292 |                 skip_connection_input=skip_input,
293 |                 inference_mode=inference_mode,
294 |                 bu_value=bu_value,
295 |                 n_img_prior=n_img_prior,
296 |                 use_mode=use_mode,
297 |                 force_constant_output=constant_out,
298 |                 forced_latent=forced_latent[i],
299 |             )
300 |             z[i] = aux['z']  # sampled variable at this layer (batch, ch, h, w)
301 |             kl[i] = aux['kl_samplewise']  # (batch, )
302 |             kl_spatial[i] = aux['kl_spatial']  # (batch, h, w)
303 |             logprob_p += aux['logprob_p'].mean()  # mean over batch
304 | 
305 |         # Final top-down layer
306 |         out = self.final_top_down(out)
307 | 
308 |         data = {
309 |             'z': z,  # list of tensors with shape (batch, ch[i], h[i], w[i])
310 |             'kl': kl,  # list of tensors with shape (batch, )
311 |             'kl_spatial':
312 |                 kl_spatial,  # list of tensors w shape (batch, h[i], w[i])
313 |             'logprob_p': logprob_p,  # scalar, mean over batch
314 |         }
315 |         return out, data
316 | 
317 |     def pad_input(self, x):
318 |         """
319 |         Pads input x so that H and W are multiples of the overall downscale factor
320 |         :param x:
321 |         :return: Padded tensor
322 |         """
323 |         size = self.get_padded_size(x.size())
324 |         x = pad_img_tensor(x, size)
325 |         return x
326 | 
327 |     def get_padded_size(self, size):
328 |         """
329 |         Returns the smallest size (H, W) larger than or equal to the input
330 |         size, such that H and W are multiples of the overall downscale factor.
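        E.g. with an overall downscale factor of 16, a (28, 28) image is
        padded to (32, 32).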
331 |         :param size: input size, tuple either (N, C, H, W) or (H, W)
332 |         :return: list [H, W]
333 |         """
334 | 
335 |         # Overall downscale factor from input to top layer (power of 2)
336 |         dwnsc = self.overall_downscale_factor
337 | 
338 |         # Make size argument into (height, width)
339 |         if len(size) == 4:
340 |             size = size[2:]
341 |         if len(size) != 2:
342 |             msg = ("input size must be either (N, C, H, W) or (H, W), but it "
343 |                    "has length {} (size={})".format(len(size), size))
344 |             raise RuntimeError(msg)
345 | 
346 |         # Round each size up to the nearest multiple of the downscale factor
347 |         padded_size = list(((s - 1) // dwnsc + 1) * dwnsc for s in size)
348 | 
349 |         return padded_size
350 | 
351 |     def sample_prior(self, n_imgs, mode_layers=None, constant_layers=None):
352 | 
353 |         # Generate from prior
354 |         out, _ = self.topdown_pass(n_img_prior=n_imgs,
355 |                                    mode_layers=mode_layers,
356 |                                    constant_layers=constant_layers)
357 |         out = crop_img_tensor(out, self.img_shape)
358 | 
359 |         # Log likelihood and other info (per data point)
360 |         _, likelihood_data = self.likelihood(out, None)
361 | 
362 |         return likelihood_data['sample']
363 | 
364 |     def get_top_prior_param_shape(self, n_imgs=1):
365 |         # TODO num channels depends on random variable we're using
366 |         dwnsc = self.overall_downscale_factor
367 |         sz = self.get_padded_size(self.img_shape)
368 |         h = sz[0] // dwnsc
369 |         w = sz[1] // dwnsc
370 |         c = self.z_dims[-1] * 2  # mu and logvar
371 |         top_layer_shape = (n_imgs, c, h, w)
372 |         return top_layer_shape
373 | 
--------------------------------------------------------------------------------
/models/lvae_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from lib.nn import ResidualBlock, ResidualGatedBlock
5 | from lib.stochastic import NormalStochasticBlock2d
6 | 
7 | 
8 | class TopDownLayer(nn.Module):
9 |     """
10 |     Top-down layer, including stochastic sampling, KL computation, and small
11 |     deterministic ResNet with upsampling.
12 | 
13 |     The architecture when doing inference is roughly as follows:
14 |         p_params = output of top-down layer above
15 |         bu = inferred bottom-up value at this layer
16 |         q_params = merge(bu, p_params)
17 |         z = stochastic_layer(q_params)
18 |         possibly get skip connection from previous top-down layer
19 |         top-down deterministic ResNet
20 | 
21 |     When doing generation only, the value bu is not available, the
22 |     merge layer is not used, and z is sampled directly from p_params.
23 | 
24 |     If this is the top layer, at inference time the uppermost bottom-up value
25 |     is used directly as q_params, while p_params are defined in this layer
26 |     itself (instead of coming from the layer above) and can optionally be learned.
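    If stochastic_skip is enabled and this is not the top layer, the output of
    the previous top-down layer is additionally merged back in after the
    stochastic block, through a skip connection that bypasses sampling.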
27 |     """
28 | 
29 |     def __init__(self,
30 |                  z_dim,
31 |                  n_res_blocks,
32 |                  n_filters,
33 |                  is_top_layer=False,
34 |                  downsampling_steps=None,
35 |                  nonlin=None,
36 |                  merge_type=None,
37 |                  batchnorm=True,
38 |                  dropout=None,
39 |                  stochastic_skip=False,
40 |                  res_block_type=None,
41 |                  gated=None,
42 |                  learn_top_prior=False,
43 |                  top_prior_param_shape=None,
44 |                  analytical_kl=False):
45 | 
46 |         super().__init__()
47 | 
48 |         self.is_top_layer = is_top_layer
49 |         self.z_dim = z_dim
50 |         self.stochastic_skip = stochastic_skip
51 |         self.learn_top_prior = learn_top_prior
52 |         self.analytical_kl = analytical_kl
53 | 
54 |         # Define top layer prior parameters, possibly learnable
55 |         if is_top_layer:
56 |             self.top_prior_params = nn.Parameter(
57 |                 torch.zeros(top_prior_param_shape),
58 |                 requires_grad=learn_top_prior)
59 | 
60 |         # Upsampling steps left to do in this layer (set by downsampling_steps)
61 |         dws_left = downsampling_steps
62 | 
63 |         # Define deterministic top-down block: sequence of deterministic
64 |         # residual blocks with upsampling when needed.
65 |         block_list = []
66 |         for _ in range(n_res_blocks):
67 |             do_resample = False
68 |             if dws_left > 0:
69 |                 do_resample = True
70 |                 dws_left -= 1
71 |             block_list.append(
72 |                 TopDownDeterministicResBlock(
73 |                     n_filters,
74 |                     n_filters,
75 |                     nonlin,
76 |                     upsample=do_resample,
77 |                     batchnorm=batchnorm,
78 |                     dropout=dropout,
79 |                     res_block_type=res_block_type,
80 |                     gated=gated,
81 |                 ))
82 |         self.deterministic_block = nn.Sequential(*block_list)
83 | 
84 |         # Define stochastic block with 2d convolutions
85 |         self.stochastic = NormalStochasticBlock2d(
86 |             c_in=n_filters,
87 |             c_vars=z_dim,
88 |             c_out=n_filters,
89 |             transform_p_params=(not is_top_layer),
90 |         )
91 | 
92 |         if not is_top_layer:
93 | 
94 |             # Merge layer: combines bottom-up inference with the top-down
95 |             # generative path to give the posterior parameters
96 |             self.merge = MergeLayer(
97 |                 channels=n_filters,
98 |                 merge_type=merge_type,
99 |                 nonlin=nonlin,
100 |                 batchnorm=batchnorm,
101 |                 dropout=dropout,
102 |                 res_block_type=res_block_type,
103 |             )
104 | 
105 |             # Skip connection that goes around the stochastic top-down layer
106 |             if stochastic_skip:
107 |                 self.skip_connection_merger = SkipConnectionMerger(
108 |                     channels=n_filters,
109 |                     nonlin=nonlin,
110 |                     batchnorm=batchnorm,
111 |                     dropout=dropout,
112 |                     res_block_type=res_block_type,
113 |                 )
114 | 
115 |     def forward(self,
116 |                 input_=None,
117 |                 skip_connection_input=None,
118 |                 inference_mode=False,
119 |                 bu_value=None,
120 |                 n_img_prior=None,
121 |                 forced_latent=None,
122 |                 use_mode=False,
123 |                 force_constant_output=False):
124 | 
125 |         # Check consistency of arguments
126 |         inputs_none = input_ is None and skip_connection_input is None
127 |         if self.is_top_layer and not inputs_none:
128 |             raise ValueError("In top layer, inputs should be None")
129 | 
130 |         # If top layer, define parameters of prior p(z_L)
131 |         if self.is_top_layer:
132 |             p_params = self.top_prior_params
133 | 
134 |             # Sample specific number of images by expanding the prior
135 |             if n_img_prior is not None:
136 |                 p_params = p_params.expand(n_img_prior, -1, -1, -1)
137 | 
138 |         # Else the input from the layer above is the prior parameters
139 |         else:
140 |             p_params = input_
141 | 
142 |         # In inference mode, get parameters of q from inference path,
143 |         # merging with top-down path if it's not the top layer
144 |         if inference_mode:
145 |             if self.is_top_layer:
146 |                 q_params = bu_value
147 |             else:
148 |                 q_params = self.merge(bu_value, p_params)
149 | 
150 |         # In generative mode, q is not used
151 |         else:
152 |             q_params
= None 153 | 154 | # Sample from either q(z_i | z_{i+1}, x) or p(z_i | z_{i+1}) 155 | # depending on whether q_params is None 156 | x, data_stoch = self.stochastic( 157 | p_params=p_params, 158 | q_params=q_params, 159 | forced_latent=forced_latent, 160 | use_mode=use_mode, 161 | force_constant_output=force_constant_output, 162 | analytical_kl=self.analytical_kl, 163 | ) 164 | 165 | # Skip connection from previous layer 166 | if self.stochastic_skip and not self.is_top_layer: 167 | x = self.skip_connection_merger(x, skip_connection_input) 168 | 169 | # Save activation before residual block: could be the skip 170 | # connection input in the next layer 171 | x_pre_residual = x 172 | 173 | # Last top-down block (sequence of residual blocks) 174 | x = self.deterministic_block(x) 175 | 176 | keys = ['z', 'kl_samplewise', 'kl_spatial', 'logprob_p', 'logprob_q'] 177 | data = {k: data_stoch[k] for k in keys} 178 | return x, x_pre_residual, data 179 | 180 | 181 | class BottomUpLayer(nn.Module): 182 | """ 183 | Bottom-up deterministic layer for inference, roughly the same as the 184 | small deterministic Resnet in top-down layers. Consists of a sequence of 185 | bottom-up deterministic residual blocks with downsampling. 186 | """ 187 | 188 | def __init__(self, 189 | n_res_blocks, 190 | n_filters, 191 | downsampling_steps=0, 192 | nonlin=None, 193 | batchnorm=True, 194 | dropout=None, 195 | res_block_type=None, 196 | gated=None): 197 | super().__init__() 198 | 199 | bu_blocks = [] 200 | for _ in range(n_res_blocks): 201 | do_resample = False 202 | if downsampling_steps > 0: 203 | do_resample = True 204 | downsampling_steps -= 1 205 | bu_blocks.append( 206 | BottomUpDeterministicResBlock( 207 | c_in=n_filters, 208 | c_out=n_filters, 209 | nonlin=nonlin, 210 | downsample=do_resample, 211 | batchnorm=batchnorm, 212 | dropout=dropout, 213 | res_block_type=res_block_type, 214 | gated=gated, 215 | )) 216 | self.net = nn.Sequential(*bu_blocks) 217 | 218 | def forward(self, x): 219 | return self.net(x) 220 | 221 | 222 | class ResBlockWithResampling(nn.Module): 223 | """ 224 | Residual block that takes care of resampling steps (each by a factor of 2). 225 | 226 | The mode can be top-down or bottom-up, and the block does up- and 227 | down-sampling by a factor of 2, respectively. Resampling is performed at 228 | the beginning of the block, through strided convolution. 229 | 230 | The number of channels is adjusted at the beginning and end of the block, 231 | through convolutional layers with kernel size 1. The number of internal 232 | channels is by default the same as the number of output channels, but 233 | min_inner_channels overrides this behaviour. 234 | 235 | Other parameters: kernel size, nonlinearity, and groups of the internal 236 | residual block; whether batch normalization and dropout are performed; 237 | whether the residual path has a gate layer at the end. There are a few 238 | residual block structures to choose from. 
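    Schematically: x -> [strided conv (bottom-up) or transposed conv (top-down)
    if resampling, else 1x1 conv if the channel count changes] -> residual
    block -> [1x1 conv if the channel count changes] -> out.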
239 | """ 240 | 241 | def __init__(self, 242 | mode, 243 | c_in, 244 | c_out, 245 | nonlin=nn.LeakyReLU, 246 | resample=False, 247 | res_block_kernel=None, 248 | groups=1, 249 | batchnorm=True, 250 | res_block_type=None, 251 | dropout=None, 252 | min_inner_channels=None, 253 | gated=None): 254 | super().__init__() 255 | assert mode in ['top-down', 'bottom-up'] 256 | if min_inner_channels is None: 257 | min_inner_channels = 0 258 | inner_filters = max(c_out, min_inner_channels) 259 | 260 | # Define first conv layer to change channels and/or up/downsample 261 | if resample: 262 | if mode == 'bottom-up': # downsample 263 | self.pre_conv = nn.Conv2d(in_channels=c_in, 264 | out_channels=inner_filters, 265 | kernel_size=3, 266 | padding=1, 267 | stride=2, 268 | groups=groups) 269 | elif mode == 'top-down': # upsample 270 | self.pre_conv = nn.ConvTranspose2d(in_channels=c_in, 271 | out_channels=inner_filters, 272 | kernel_size=3, 273 | padding=1, 274 | stride=2, 275 | groups=groups, 276 | output_padding=1) 277 | elif c_in != inner_filters: 278 | self.pre_conv = nn.Conv2d(c_in, inner_filters, 1, groups=groups) 279 | else: 280 | self.pre_conv = None 281 | 282 | # Residual block 283 | self.res = ResidualBlock( 284 | channels=inner_filters, 285 | nonlin=nonlin, 286 | kernel=res_block_kernel, 287 | groups=groups, 288 | batchnorm=batchnorm, 289 | dropout=dropout, 290 | gated=gated, 291 | block_type=res_block_type, 292 | ) 293 | 294 | # Define last conv layer to get correct num output channels 295 | if inner_filters != c_out: 296 | self.post_conv = nn.Conv2d(inner_filters, c_out, 1, groups=groups) 297 | else: 298 | self.post_conv = None 299 | 300 | def forward(self, x): 301 | if self.pre_conv is not None: 302 | x = self.pre_conv(x) 303 | x = self.res(x) 304 | if self.post_conv is not None: 305 | x = self.post_conv(x) 306 | return x 307 | 308 | 309 | class TopDownDeterministicResBlock(ResBlockWithResampling): 310 | 311 | def __init__(self, *args, upsample=False, **kwargs): 312 | kwargs['resample'] = upsample 313 | super().__init__('top-down', *args, **kwargs) 314 | 315 | 316 | class BottomUpDeterministicResBlock(ResBlockWithResampling): 317 | 318 | def __init__(self, *args, downsample=False, **kwargs): 319 | kwargs['resample'] = downsample 320 | super().__init__('bottom-up', *args, **kwargs) 321 | 322 | 323 | class MergeLayer(nn.Module): 324 | """ 325 | Merge two 4D input tensors by concatenating along dim=1 and passing the 326 | result through 1) a convolutional 1x1 layer, or 2) a residual block 327 | """ 328 | 329 | def __init__(self, 330 | channels, 331 | merge_type, 332 | nonlin=nn.LeakyReLU, 333 | batchnorm=True, 334 | dropout=None, 335 | res_block_type=None): 336 | super().__init__() 337 | try: 338 | iter(channels) 339 | except TypeError: # it is not iterable 340 | channels = [channels] * 3 341 | else: # it is iterable 342 | if len(channels) == 1: 343 | channels = [channels[0]] * 3 344 | assert len(channels) == 3 345 | 346 | if merge_type == 'linear': 347 | self.layer = nn.Conv2d(channels[0] + channels[1], channels[2], 1) 348 | elif merge_type == 'residual': 349 | self.layer = nn.Sequential( 350 | nn.Conv2d(channels[0] + channels[1], channels[2], 1, padding=0), 351 | ResidualGatedBlock(channels[2], 352 | nonlin, 353 | batchnorm=batchnorm, 354 | dropout=dropout, 355 | block_type=res_block_type), 356 | ) 357 | 358 | def forward(self, x, y): 359 | x = torch.cat((x, y), dim=1) 360 | return self.layer(x) 361 | 362 | 363 | class SkipConnectionMerger(MergeLayer): 364 | """ 365 | By default for now simply 
a merge layer. 366 | """ 367 | 368 | merge_type = 'residual' 369 | 370 | def __init__(self, channels, nonlin, batchnorm, dropout, res_block_type): 371 | super().__init__(channels, 372 | self.merge_type, 373 | nonlin, 374 | batchnorm, 375 | dropout=dropout, 376 | res_block_type=res_block_type) 377 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.1,<=1.19.0 2 | torch>=1.4.0,<=1.5.1 3 | torchvision>=0.5.0,<=0.6.1 4 | matplotlib>=3.1.2 5 | seaborn>=0.9.0,<=0.10.1 6 | boilr==0.7.4 7 | multiobject==0.0.3 8 | --------------------------------------------------------------------------------