├── .gitignore ├── LICENSE ├── README.md ├── biva ├── __init__.py ├── datasets │ ├── __init__.py │ ├── binmnist.py │ └── cifar10.py ├── evaluation │ ├── __init__.py │ ├── freebits.py │ └── vi.py ├── layers │ ├── __init__.py │ ├── block.py │ ├── convolution.py │ └── linear.py ├── model │ ├── __init__.py │ ├── architectures.py │ ├── deepvae.py │ ├── stage.py │ ├── stochastic.py │ └── utils.py └── utils │ ├── __init__.py │ ├── discretized_mixture_logits.py │ ├── logging.py │ ├── ops.py │ ├── restore.py │ └── utils.py ├── example.py ├── load_deepvae.py ├── requirements.txt ├── run_deepvae.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | runs/ 3 | *.egg-info 4 | *dist/ 5 | .DS_store 6 | .idea/ 7 | __pycache__/ 8 | build/ 9 | .ipynb_checkpoints/ 10 | 11 | *.ipynb 12 | /output/ 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Valentin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BIVA (PyTorch) 2 | 3 | Official PyTorch BIVA implementation (BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling) for binarized MNIST and CIFAR. The original TensorFlow implementation can be found [here](https://github.com/larsmaaloee/BIVA). 4 | 5 | ## Run the Experiments 6 | 7 | ```bash 8 | conda create --name biva python=3.7 9 | conda activate biva 10 | pip install -r requirements.txt 11 | CUDA_VISIBLE_DEVICES=0 python run_deepvae.py --dataset binmnist --q_dropout 0.5 --p_dropout 0.5 --device cuda 12 | CUDA_VISIBLE_DEVICES=0 python run_deepvae.py --dataset cifar10 --q_dropout 0.2 --p_dropout 0 --device cuda 13 | ``` 14 | 15 | ## Citation 16 | 17 | ``` 18 | @article{maale2019biva, 19 | title={BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling}, 20 | author={Lars Maaløe and Marco Fraccaro and Valentin Liévin and Ole Winther}, 21 | year={2019}, 22 | eprint={1902.02102}, 23 | archivePrefix={arXiv}, 24 | primaryClass={stat.ML} 25 | } 26 | ``` 27 | 28 | ## Pip package 29 | 30 | ### Install the Requirements 31 | 32 | * `pytorch 1.3.0` 33 | * `torchvision` 34 | * `matplotlib` 35 | * `tensorboard` 36 | * `booster-pytorch==0.0.2` 37 | 38 | ### Install as a Package 39 | 40 | ```bash 41 | pip install git+https://github.com/vlievin/biva-pytorch.git 42 | ``` 43 | 44 | ### Build Deep VAEs 45 | 46 | ```python 47 | import torch 48 | from torch.distributions import Bernoulli 49 | 50 | from biva import DenseNormal, ConvNormal 51 | from biva import VAE, LVAE, BIVA 52 | 53 | # build a 2 
def load_mnist_binarized(root):
    """
    Load the binarized MNIST splits, downloading and caching them on first use.

    The three `.amat` text files are fetched once, parsed with `np.loadtxt` and stored as a
    single pickle under `<root>/bin-mnist/`. Later calls only read that cache.

    :param root: directory under which the `bin-mnist` cache folder is created
    :return: tuple `(x_train, x_valid, x_test)` of arrays as parsed by `np.loadtxt`
    """
    datapath = os.path.join(root, 'bin-mnist')
    # `exist_ok` avoids the check-then-create race of `if not exists: makedirs`
    os.makedirs(datapath, exist_ok=True)
    # NOTE: despite its `.gz` suffix the cache is a plain, uncompressed pickle;
    # the name is kept so that caches written by earlier versions stay valid.
    dataset = os.path.join(datapath, "mnist.pkl.gz")

    if not os.path.isfile(dataset):
        datafiles = {
            "train": "http://www.cs.toronto.edu/~larocheh/public/"
                     "datasets/binarized_mnist/binarized_mnist_train.amat",
            "valid": "http://www.cs.toronto.edu/~larocheh/public/datasets/"
                     "binarized_mnist/binarized_mnist_valid.amat",
            "test": "http://www.cs.toronto.edu/~larocheh/public/datasets/"
                    "binarized_mnist/binarized_mnist_test.amat"
        }
        datasplits = {}
        for split, url in datafiles.items():
            print("Downloading %s data..." % (split))
            local_path, _ = urlretrieve(url)
            datasplits[split] = np.loadtxt(local_path)

        # Write to a temporary file and rename it afterwards, so that an interrupted
        # run never leaves a truncated cache that would break every subsequent load.
        tmp_path = dataset + ".tmp"
        with open(tmp_path, "wb") as f:
            pkl.dump([datasplits['train'], datasplits['valid'], datasplits['test']], f)
        os.replace(tmp_path, dataset)

    # NOTE: the pickle is a local cache written by this function; do not point `root`
    # at an untrusted directory, since unpickling can execute arbitrary code.
    with open(dataset, "rb") as f:
        x_train, x_valid, x_test = pkl.load(f)
    return x_train, x_valid, x_test
def load_cifar(root, levels=256, with_y=False):
    """
    Load CIFAR-10 from the official python archive, downloading it into `root` if needed.

    An archive named `cifar-10-python.tar.gz` in the current working directory takes
    precedence; otherwise `<root>/cifar-10-python.tar.gz` is used and downloaded when missing.
    Pixel values are scaled to [0, 1], reshaped to (N, 3, 32, 32), quantized into `levels`
    bins with `quantisize` and rescaled by `levels - 1`.

    :param root: directory where the archive is stored / downloaded
    :param levels: number of quantization levels
    :param with_y: when True, also return the integer labels
    :return: `(train_x, test_x)`, or `((train_x, train_t), (test_x, test_t))` if `with_y`
    """
    archive_name = 'cifar-10-python.tar.gz'
    origin = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    batch_dir = 'cifar-10-batches-py'

    # prefer an archive in the working directory, otherwise fall back to `root`
    dataset = archive_name if os.path.isfile(archive_name) else os.path.join(root, archive_name)

    if not os.path.isfile(dataset):
        # the download target directory may not exist yet (an empty root means the cwd)
        if root:
            os.makedirs(root, exist_ok=True)
        print("Downloading data from {}...".format(origin))
        urlretrieve(origin, dataset)

    def _read_batch(archive, name):
        # NOTE(security): the archive comes from a plain-HTTP origin and is unpickled;
        # a tampered archive could execute arbitrary code. Verify the file if in doubt.
        member = archive.extractfile("{}/{}".format(batch_dir, name))
        return pkl.load(member, encoding="bytes")

    # the context manager closes the archive even if one of the reads fails
    with tarfile.open(dataset, 'r:gz') as archive:
        train_batches = [_read_batch(archive, "data_batch_{}".format(i)) for i in range(1, 6)]
        test_batch = _read_batch(archive, "test_batch")

    train_x = np.concatenate([batch[b'data'] for batch in train_batches], axis=0) / 255.
    train_x = np.asarray(train_x, dtype='float32')
    train_t = np.concatenate([np.array(batch[b'labels']) for batch in train_batches], axis=0)

    test_x = np.asarray(test_batch[b'data'] / 255., dtype='float32')
    test_t = np.array(test_batch[b'labels'])

    train_x = train_x.reshape((train_x.shape[0], 3, 32, 32))
    test_x = test_x.reshape((test_x.shape[0], 3, 32, 32))
    train_x = quantisize(train_x, levels) / (levels - 1.)
    test_x = quantisize(test_x, levels) / (levels - 1.)

    if with_y:
        return (train_x, train_t), (test_x, test_t)
    return train_x, test_x
class VariationalInference(object):
    """
    Evaluate a latent-variable model with (importance-weighted) variational inference.

    Calling the evaluator runs the model `iw_samples` times on the same input and returns the
    loss, a `Diagnostic` summary and the outputs of the last forward pass.
    """

    def __init__(self, likelihood: Callable[..., Any], iw_samples: int = 1,
                 auxiliary: Optional[Dict[str, float]] = None, **parameters: Any):
        """
        Initialize VI evaluator
        :param likelihood: likelihood class used to evaluate log p(x | z), called as `likelihood(logits=...)`
        :param iw_samples: number of importance-weighted samples
        :param auxiliary: dict of auxiliary losses, each key must have a match in the model output, the auxiliary values define the weight of the auxiliary loss in the overall loss
        :param parameters: additional parameters passed to model and evaluator
        """
        super().__init__()

        assert iw_samples > 0
        self._iw_samples = iw_samples
        self._parameters = parameters
        # build a fresh dict per instance: a `{}` default would be shared by all evaluators
        self._auxiliary = {} if auxiliary is None else auxiliary
        self.likelihood = likelihood

    @staticmethod
    def compute_kls(kls: Union[Tensor, List[Tensor]], freebits: Optional[Union[float, List[float]]], device: str):
        """
        Compute the total KL and the KL to be accounted for in the loss (after freebits).
        :param kls: a KL tensor or a list of KL tensors (one per stochastic layer); None / [] means no KL
        :param freebits: freebits budget, a single value applied to every layer or one value per layer; None disables freebits
        :param device: device used for the zero placeholder when there is no KL
        :return: (kl, kl_loss), both reduced with `batch_reduce` (project helper, presumably one value per sample)
        """

        if kls is None or (isinstance(kls, list) and len(kls) == 0):
            _zero = detach_to_device(0., device)
            return _zero, _zero

        # set kls and freebits as lists
        if not isinstance(kls, list):
            kls = [kls]

        # apply freebits to each layer
        if freebits is None:
            kls_loss = kls
        else:
            if not isinstance(freebits, list):
                freebits = [freebits] * len(kls)
            kls_loss = [FreeBits(fb)(kl) for fb, kl in zip(freebits, kls)]

        def _total(terms):
            # reduce each term, stack the results as columns and reduce across layers
            return batch_reduce(torch.cat([batch_reduce(term)[:, None] for term in terms], 1))

        return _total(kls), _total(kls_loss)

    def compute_elbo(self, x, outputs, beta=1.0, freebits=None, **kwargs):
        """
        Compute the loss and the ELBO for a single forward pass.
        :param x: input tensor
        :param outputs: model output dict, read keys: 'x_' (likelihood logits), 'kl' and the auxiliary keys
        :param beta: weight of the KL term in the loss
        :param freebits: freebits budget passed to `compute_kls`
        :param kwargs: may override the weight of an auxiliary loss by key
        :return: (loss, elbo, kls, kl, nll, auxiliary), where `kls` is the raw `outputs['kl']` and
                 `auxiliary` maps each key to a tuple (weight, value)
        """

        # Destructuring dict
        x_ = outputs.get('x_')
        kls = outputs.get('kl')

        # compute E_p(x) [ - log p_\theta(x | z) ]
        nll = - batch_reduce(self.likelihood(logits=x_).log_prob(x))

        # compute kl: \sum_i E_q(z_i) [ log q(z_i | h) - log p(z_i | h) ]
        kl, kls_loss = self.compute_kls(kls, freebits, x.device)

        # compute total loss and elbo (freebits and beta only affect the loss)
        loss = nll + beta * kls_loss
        elbo = -(nll + kl)

        # compute auxiliary losses / kls
        auxiliary = {}
        for k, default_value in self._auxiliary.items():
            value = outputs.get(k, None)
            if value is None:
                continue
            value, _ = self.compute_kls(value, None, x.device)

            # get custom weights from kwargs
            weight = kwargs.get(k, default_value)

            # add to loss
            loss = loss + weight * value

            # store as a tuple
            auxiliary[k] = (weight, value)

        return loss, elbo, kls, kl, nll, auxiliary

    @staticmethod
    def _effective_sample_size(log_weights: Tensor) -> Tensor:
        """
        Effective sample size (sum_i w_i)^2 / sum_i w_i^2 computed from log-weights of shape [samples, batch].

        Working in log-space is required: exp(elbo) underflows to 0 in float32 for realistic
        ELBO values, which turns the naive ratio into NaN / inf.
        """
        log_sum = torch.logsumexp(log_weights, dim=0)
        log_sum_of_squares = torch.logsumexp(2 * log_weights, dim=0)
        return torch.exp(2 * log_sum - log_sum_of_squares)

    def __call__(self, model: nn.Module, x: Tensor, **kwargs: Any) -> Tuple[Tensor, Dict, Dict]:
        """
        Process inputs using model and compute loss, ELBO and diagnostics.
        :param model: model to evaluate
        :param x: input tensor
        :param kwargs: other args passed both to the model and the evaluator
        :return: (loss, diagnostics, outputs) where `outputs` is the output of the last forward pass
        """

        # update kwargs: parameters given at construction take precedence
        kwargs.update(self._parameters)

        n_samples, batch_size = self._iw_samples, x.size(0)

        def _placeholder():
            return torch.zeros((n_samples, batch_size), device=x.device, dtype=torch.float)

        # importance-weighted placeholders
        iw_elbos, iw_kls, iw_nlls = _placeholder(), _placeholder(), _placeholder()

        # NOTE: `loss`, `kls`, `auxiliary` and `outputs` are those of the LAST sample only
        for k in range(n_samples):
            # forward pass
            outputs = model(x, **kwargs)

            # compute VI elbo; log w_k = log p(x, z_k) - log q(z_k | x) = elbo_k
            loss, elbo, kls, kl, nll, auxiliary = self.compute_elbo(x, outputs, **kwargs)
            iw_elbos[k, :] = elbo
            iw_kls[k, :] = kl
            iw_nlls[k, :] = nll

        if n_samples > 1:
            elbo = log_sum_exp(iw_elbos, dim=0, sum_op=torch.mean)
            kl = iw_kls.mean(0)
            nll = iw_nlls.mean(0)

        # Compute effective sample size
        N_eff = self._effective_sample_size(iw_elbos.detach())

        # gather diagnostics
        bits_per_dim = - elbo / math.log(2.) / np.prod(x.size()[1:])
        to_device = partial(detach_to_device, device=x.device)
        diagnostics = {
            "loss": {"loss": to_device(loss), "elbo": to_device(elbo), "kl": to_device(kl), "nll": to_device(nll),
                     "bpd": to_device(bits_per_dim)},
            "info": {"N_eff": to_device(N_eff), "batch_size": x.size(0)}
        }

        # add auxiliary
        for k, (weight, value) in auxiliary.items():
            diagnostics['loss'][k] = to_device(value.float().mean())

        # add kls: one entry per stochastic layer; tolerate a missing or single-tensor 'kl'
        if kls is None:
            kls = []
        elif isinstance(kls, Tensor):
            kls = [kls]
        diagnostics['kl'] = {f'kl-{i}': v.mean() for i, v in enumerate(kls)}

        # add other params:
        def _check_type(v):
            return isinstance(v, float) or (isinstance(v, Tensor) and v.dim() == 0)

        diagnostics['parameters'] = {k: v for k, v in kwargs.items() if _check_type(v)}

        diagnostics = Diagnostic(diagnostics).to(x.device)

        return loss.mean(), diagnostics, outputs
input tensor shape (B x C x *D). None means no auxialiary input 16 | weightnorm (bool): use weight normalization 17 | act (nn.Module): activation function class 18 | transposed (bool): transposed or not 19 | dropout (float): dropout value. None is no dropout 20 | residual (bool): use residual connections 21 | """ 22 | super(GatedResNet, self).__init__() 23 | 24 | # define convolution and transposed convolution objects 25 | if len(input_shape[2:]) == 1: 26 | conv_obj = nn.Conv1d 27 | deconv_obj = nn.ConvTranspose1d 28 | dropout_obj = nn.Dropout 29 | else: 30 | conv_obj = nn.Conv2d 31 | deconv_obj = nn.ConvTranspose2d 32 | dropout_obj = nn.Dropout 33 | 34 | # some parameters 35 | C_in = input_shape[1] 36 | self.residual = residual 37 | self.transposed = transposed 38 | self.act = act() 39 | 40 | # conv 1 41 | conv1 = conv_obj(C_in, dim[0], dim[1], 1) 42 | self.conv1 = PaddedNormedConv(input_shape, conv1, weightnorm=weightnorm) 43 | shp = self.conv1.output_shape 44 | 45 | # dropout 46 | self.dropout = dropout_obj(dropout) if dropout is not None else dropout 47 | 48 | # conv 2 49 | if self.transposed and dim[2] > 1: 50 | conv2 = deconv_obj(dim[0], 2 * dim[0], dim[1], dim[2]) 51 | else: 52 | conv2 = conv_obj(dim[0], 2 * dim[0], dim[1], dim[2]) 53 | 54 | self.conv2 = PaddedNormedConv(shp, conv2, weightnorm=weightnorm) 55 | 56 | # input / output shapes 57 | shp = list(self.conv2.output_shape) 58 | shp[1] = shp[1] // 2 # gated 59 | self._input_shape = input_shape 60 | self._output_shape = tuple(shp) 61 | self.aux_shape = aux_shape 62 | 63 | # residual connections 64 | self.residual_op = ResidualConnection(self._input_shape, shp, residual) 65 | 66 | # aux op 67 | if aux_shape is not None: 68 | if list(aux_shape[2:]) > list(input_shape[2:]): 69 | stride = tuple(np.asarray(aux_shape[2:]) // np.asarray(input_shape[2:])) 70 | aux_conv = conv_obj(aux_shape[1], dim[0], dim[1], stride) 71 | self.aux_op = PaddedNormedConv(aux_shape, aux_conv, weightnorm=weightnorm) 72 | 73 | elif 
list(aux_shape[2:]) < list(input_shape[2:]): 74 | stride = tuple(np.asarray(input_shape[2:]) // np.asarray(aux_shape[2:])) 75 | aux_conv = deconv_obj(aux_shape[1], dim[0], dim[1], stride) 76 | self.aux_op = PaddedNormedConv(aux_shape, aux_conv, weightnorm=weightnorm) 77 | 78 | else: 79 | aux_conv = conv_obj(aux_shape[1], dim[0], 1, 1) # conv with kernel 1 80 | self.aux_op = PaddedNormedConv(aux_shape, aux_conv, weightnorm=weightnorm) 81 | 82 | else: 83 | self.aux_op = None 84 | 85 | def forward(self, x: torch.Tensor, aux: Optional[torch.Tensor] = None, **kwargs: Any) -> torch.Tensor: 86 | 87 | # input activation: x = act(x) 88 | x_act = self.act(x) if self.residual else x 89 | 90 | # conv 1: y = conv(x) 91 | y = self.conv1(x_act) 92 | 93 | # merge aux with x: y = y + f(aux) 94 | y = y + self.aux_op(self.act(aux)) if self.aux_op is not None else y 95 | 96 | # y = act(y) 97 | y = self.act(y) 98 | 99 | # dropout 100 | y = self.dropout(y) if self.dropout else y 101 | 102 | # conv 2: y = conv(y) 103 | y = self.conv2(y) 104 | 105 | # gate: y = y_1 * sigmoid(y_2) 106 | h_stack1, h_stack2 = y.chunk(2, 1) 107 | sigmoid_out = torch.sigmoid(h_stack2) 108 | y = (h_stack1 * sigmoid_out) 109 | 110 | # resiudal connection: y = y + x 111 | y = self.residual_op(y, x) 112 | 113 | return y 114 | 115 | @property 116 | def input_shape(self) -> Tuple: 117 | return self._input_shape 118 | 119 | @property 120 | def output_shape(self) -> Tuple: 121 | return self._output_shape 122 | 123 | 124 | class ResMLP(nn.Module): 125 | def __init__(self, input_shape: Tuple[int], dim: int, aux_shape: Optional[int] = None, weightnorm: bool = True, 126 | act: nn.Module = nn.ReLU, transposed: bool = False, dropout: float = None, residual: bool = True, 127 | mlp_layers: int = 1, 128 | **kwargs: Any): 129 | super().__init__() 130 | 131 | # convert parameters 132 | ninp = input_shape[1] 133 | nhid = dim 134 | naux = aux_shape[1] if aux_shape is not None else 0 135 | nlayers = mlp_layers 136 | 137 | # params 
def getSAMEPadding(tensor_shape, conv):
    """
    return the padding to apply to a given convolution such as it reproduces the 'SAME' behavior from Tensorflow
    This works as well for pooling layers.

    args:
        tensor_shape (tuple): input tensor shape (B x C x *D), only the spatial dimensions are used
        conv (nn.Module): convolution (or pooling) object; a missing `transposed` attribute
            (pooling layers) is set to False on the object as a side effect
    returns:
        sym_padding, unsym_padding:
            sym_padding (tuple of int or None): symmetric padding per spatial dimension, meant to be
                assigned to `conv.padding`; None for transposed convolutions.
            unsym_padding (list of int): extra padding in the layout expected by `F.pad`, i.e.
                (before, after) pairs ordered from the LAST spatial dimension to the first.
                Values are negative (cropping) for transposed convolutions with kernel > stride.
    """
    tensor_shape = np.asarray(tensor_shape)[2:]
    kernel_size = np.asarray(conv.kernel_size)
    dilation = np.asarray(conv.dilation) if hasattr(conv, 'dilation') else 1
    stride = np.asarray(conv.stride)

    # handle pooling layers
    if not hasattr(conv, 'transposed'):
        conv.transposed = False
    else:
        assert len(tensor_shape) == len(kernel_size), "tensor is not the same dimension as the kernel"

    if not conv.transposed:
        # SAME: output = ceil(input / stride); total padding needed to reach that size
        output_size = (tensor_shape + stride - 1) // stride
        padding_input = np.maximum(0, (output_size - 1) * stride + (kernel_size - 1) * dilation + 1 - tensor_shape)
        # half of the padding goes on both sides, the odd remainder goes after the data
        sym_padding = tuple(int(p) // 2 for p in padding_input)
        pairs = [(0, int(p) % 2) for p in padding_input]
    else:
        # transposed: the output exceeds `input * stride` by `kernel - stride`, crop it away
        sym_padding = None
        pairs = [(-(int(e) // 2), -(int(e) // 2 + int(e) % 2)) for e in kernel_size - stride]

    # `F.pad` consumes pairs starting from the last dimension, whereas `pairs` follows the
    # tensor layout (e.g. H then W): reverse it so each pair lands on its own dimension.
    unsym_padding = [p for pair in reversed(pairs) for p in pair]
    return sym_padding, unsym_padding
class PaddedNormedConv(nn.Module):
    """
    A class to handle both normalization and SAME padding (as in tensorflow) for convolutions.
    This class also handles data dependent initialization for weight normalization:
    the first forward pass rescales `weight_g` and shifts `bias` using the statistics of the
    convolution output computed on that first batch.
    """

    def __init__(self, tensor_shape, conv, weightnorm=True):
        """
        args:
            tensor_shape (tuple): input tensor shape (B x C x D)
            conv (nn.Module): convolution instance of type Conv1d, ConvTranspose1d, Conv2d or ConvTranspose2d
            weightnorm (bool): use weight normalization
        """
        super(PaddedNormedConv, self).__init__()
        # flag stored as a buffer; it is set during the first forward pass
        self.register_buffer("initialized", torch.tensor(False))

        # input / output shapes: spatial dimensions are divided by the stride
        # (multiplied for transposed convolutions), channels follow `conv.out_channels`
        self._input_shp = tensor_shape
        self._output_shp = np.asarray(tensor_shape)
        stride_factor = np.asarray(conv.stride)
        self._output_shp[2:] = self._output_shp[2:] // stride_factor if (
                not hasattr(conv, 'transposed') or not conv.transposed) else self._output_shp[2:] * stride_factor
        self._output_shp[1] = conv.out_channels if hasattr(conv, 'out_channels') else self._output_shp[1]
        self._output_shp = tuple(self._output_shp.astype(int))

        # get paddings: the symmetric part is delegated to the convolution itself,
        # the asymmetric part is applied with F.pad in `forward`
        sym_padding, self.unsym_padding = getSAMEPadding(tensor_shape, conv)
        if not conv.transposed:
            conv.padding = sym_padding
        self.conv = conv

        # weight normalization (no other normalization layer is added when it is disabled)
        if not weightnorm:
            self.weightnorm = False
        else:
            self.weightnorm = True
            # norm is taken per output channel; transposed convolutions store weights
            # as (in, out, ...), hence dim=1 in that case
            dim = 1 if self.conv.transposed else 0
            self.conv = nn.utils.weight_norm(self.conv, dim=dim, name="weight")

    def forward(self, x):
        # regular convolutions are padded before the convolution ...
        x = F.pad(x, self.unsym_padding) if not self.conv.transposed else x
        # data dependent initialization, only while the `initialized` flag is still False
        if not self.initialized:
            self.init_parameters(x)
        x = self.conv(x)

        # ... whereas transposed convolutions are padded after it
        x = F.pad(x, self.unsym_padding) if self.conv.transposed else x
        return x

    @property
    def input_shape(self):
        return self._input_shp

    @property
    def output_shape(self):
        return self._output_shp

    def init_parameters(self, x, init_scale=0.05, eps=1e-8):
        """
        Data dependent initialization of the weight-normalized convolution.

        args:
            x (tensor): (already padded) input batch used to compute the statistics
            init_scale (float): std of the initial `weight_v` values
            eps (float): constant added to the variance before the square root
        returns:
            the normalized convolution output when weight normalization is enabled, None otherwise
            (`forward` ignores this value)
        """
        # adding to the existing buffer keeps the flag a tensor
        # (presumably because a plain bool cannot be assigned to a registered buffer — TODO confirm)
        self.initialized = True + self.initialized
        if self.weightnorm:
            # initial values
            # NOTE(review): assumes the convolution was built with a bias — confirm for bias=False
            self.conv._parameters['weight_v'].data.normal_(mean=0, std=init_scale)
            self.conv._parameters['weight_g'].data.fill_(1.)
            self.conv._parameters['bias'].data.fill_(0.)
            # from here on `init_scale` is the target std of the output, not the std of weight_v
            init_scale = .01
            # data dependent init
            x = self.conv(x)
            # flatten to (B * spatial, C) to get per-channel statistics
            t = x.view(x.size()[0], x.size()[1], -1)
            t = t.permute(0, 2, 1).contiguous()
            t = t.view(-1, t.size()[-1])
            m_init, v_init = torch.mean(t, 0), torch.var(t, 0)
            scale_init = init_scale / torch.sqrt(v_init + eps)
            # rescale the gain and shift the bias so that the output is centered
            # with std `init_scale` on this batch
            if self.conv.transposed:
                self.conv._parameters['weight_g'].data = self.conv._parameters['weight_g'].data * scale_init[None,
                                                                                                  :].view(
                    self.conv._parameters['weight_g'].data.size())
                self.conv._parameters['bias'].data = self.conv._parameters['bias'].data - m_init * scale_init
            else:
                self.conv._parameters['weight_g'].data = self.conv._parameters['weight_g'].data * scale_init[:,
                                                                                                  None].view(
                    self.conv._parameters['weight_g'].data.size())
                self.conv._parameters['bias'].data = self.conv._parameters['bias'].data - m_init * scale_init
            # normalized output, broadcast over 2 spatial dims (4D input shape) or 1 (3D input shape)
            return scale_init[None, :, None, None] * (x - m_init[None, :, None, None]) if len(
                self._input_shp) > 3 else scale_init[None, :, None] * (x - m_init[None, :, None])
179 | Apply padding and/or avg pooling to the input when necessary 180 | """ 181 | 182 | def __init__(self, input_shape, output_shape, residual=True): 183 | """ 184 | args: 185 | input_shape (tuple): input module shape x 186 | output_shape (tuple): output module shape y=f(x) 187 | residual (bool): apply residual conenction y' = y+x = f(x)+x 188 | """ 189 | super().__init__() 190 | self.residual = residual 191 | self.input_shape = input_shape 192 | self.output_shape = output_shape 193 | is_text = len(input_shape) == 3 194 | 195 | # residual: features 196 | if residual and self.output_shape[1] < self.input_shape[1]: 197 | pad = int(self.output_shape[1]) - int(self.input_shape[1]) 198 | self.redidual_padding = [0, 0, 0, pad] if is_text else [0, 0, 0, 0, 0, pad] 199 | 200 | elif residual and self.output_shape[1] > self.input_shape[1]: 201 | pad = int(self.output_shape[1]) - int(self.input_shape[1]) 202 | self.redidual_padding = [0, 0, 0, pad] if is_text else [0, 0, 0, 0, 0, pad] 203 | warnings.warn("The input has more feature maps than the output. There will be no residual connection for this layer.") 204 | self.residual = False 205 | else: 206 | self.redidual_padding = None 207 | 208 | # residual: dimension 209 | if residual and list(output_shape)[2:] < list(input_shape)[2:]: 210 | pool_obj = nn.AvgPool1d if len(output_shape[2:]) == 1 else nn.AvgPool2d 211 | stride = tuple((np.asarray(input_shape)[2:] // np.asarray(output_shape)[2:]).tolist()) 212 | self.residual_op = PaddedConv(input_shape, pool_obj(3, stride=stride)) 213 | 214 | elif residual and list(output_shape)[2:] > list(input_shape)[2:]: 215 | warnings.warn( 216 | "The height and width of the output are larger than the input. 
def getConvolutionOutputShape(tensor_shape, conv):
    """
    Compute the output shape of a convolution or transposed convolution given the input tensor shape.

    The spatial sizes follow the formulas documented by PyTorch:
        convolution:            floor((D + 2 * pad - dilation * (k - 1) - 1) / stride) + 1
        transposed convolution: (D - 1) * stride - 2 * pad + dilation * (k - 1) + output_padding + 1

    args:
        tensor_shape (tuple): input tensor shape (B x C x D)
        conv (nn.Module): convolution object exposing `in_channels`, `out_channels`, `kernel_size`,
            `padding`, `dilation`, `stride`, `transposed` (and `output_padding` when transposed)
    returns:
        output_shape (tuple): expected output shape (B, conv.out_channels, *D_out); the batch entry is
            returned as given and the spatial sizes are Python ints
    raises:
        AssertionError: if the number of channels or the number of spatial dimensions does not match `conv`
    """
    assert tensor_shape[1] == conv.in_channels, "tensor and kernel do not have the same number of features"
    batch_size = tensor_shape[0]
    dims = np.asarray(tensor_shape[2:])
    kernel_size = np.asarray(conv.kernel_size)
    pad = np.asarray(conv.padding)
    dilation = np.asarray(conv.dilation)
    stride = np.asarray(conv.stride)
    assert len(dims) == len(kernel_size), "tensor is not the same dimension as the kernel"

    # extent covered by the kernel (minus one) once dilation is taken into account
    span = dilation * (kernel_size - 1)
    if not conv.transposed:
        # integer floor division is equivalent to floor(x / stride), including for negative numerators
        out_dims = (dims + 2 * pad - span - 1) // stride + 1
    else:
        out_padding = np.asarray(conv.output_padding)
        out_dims = (dims - 1) * stride - 2 * pad + span + out_padding + 1

    return (batch_size, conv.out_channels, *(int(d) for d in out_dims))
kernel_size=kernel_size, padding=padding, stride=stride, 265 | dilation=dilation) 266 | x_shp = tuple(x.size()) 267 | c_shp = getConvolutionOutputShape(x_shp, conv) 268 | true_shp = tuple(conv(x).size()) 269 | assert c_shp[2:] == true_shp[2:] 270 | 271 | # test getPadding 272 | for x in [torch.zeros(1, 3, 33, 32), torch.zeros(1, 3, 15, 15)]: 273 | for Conv in [nn.ConvTranspose2d, nn.Conv2d]: 274 | for stride in [2, 1]: 275 | for dilation in [1, 2]: 276 | for kernel_size in [3, 5, 6]: 277 | conv = Conv(3, 5, kernel_size=kernel_size, stride=stride, dilation=dilation) 278 | # print("kernel_size:",kernel_size, " stride:", stride," trans.:",conv.transposed, " x:", x.size()) 279 | x_shp = tuple(x.size()) 280 | padding, odd_padding = getSAMEPadding(x_shp, conv) 281 | if not conv.transposed: 282 | conv.padding = padding 283 | x_shp = x_shp[2:] 284 | expected_shp = tuple([t * stride for t in x_shp]) if conv.transposed else tuple( 285 | [t // stride for t in x_shp]) 286 | y = F.pad(x, odd_padding) if not conv.transposed else x 287 | y = conv(y) 288 | y = F.pad(y, odd_padding) if conv.transposed else y 289 | true_shp = tuple(y.size())[2:] 290 | # print(true_shp,expected_shp,np.asarray(true_shp)-np.asarray(expected_shp),kernel_size-stride,'\n') 291 | assert true_shp >= expected_shp 292 | 293 | # test SAMEpaddingConv 294 | for x in [torch.zeros(1, 3, 33, 32), torch.zeros(1, 3, 15, 15)]: 295 | for Conv in [nn.ConvTranspose2d, nn.Conv2d]: 296 | for stride in [2, 1]: 297 | for dilation in [1, 2]: 298 | for kernel_size in [3, 5, 6]: 299 | conv = Conv(3, 5, kernel_size=kernel_size, stride=stride, dilation=dilation) 300 | # print("kernel_size:",kernel_size, " stride:", stride," trans.:",conv.transposed, " x:", x.size()) 301 | x_shp = tuple(x.size()) 302 | conv = SAMEpaddingConv(x_shp, conv) 303 | x_shp = x_shp[2:] 304 | expected_shp = tuple([t * stride for t in x_shp]) if conv.conv.transposed else tuple( 305 | [t // stride for t in x_shp]) 306 | y = conv(x) 307 | true_shp = 
class NormedLinear(nn.Module):
    """
    Linear layer applied along a chosen dimension, with optional weight normalization.
    When weight normalization is enabled, the parameters receive a data-dependent
    initialization computed from the first batch seen by `forward`.
    """

    def __init__(self, in_features, out_features, dim=-1, weightnorm=True):
        """
        args:
            in_features (int): number of input features
            out_features (int): number of output features
            dim (int): dimension to apply the transformation to
            weightnorm (bool): use weight normalization
        """
        super(NormedLinear, self).__init__()
        # stored as a buffer so the flag follows the module (state dict, device moves)
        self.register_buffer("initialized", torch.tensor(False))
        self._in_features = in_features
        self._out_features = out_features
        self.dim = dim
        self.linear = nn.Linear(self._in_features, out_features)
        self.weightnorm = bool(weightnorm)
        if self.weightnorm:
            self.linear = nn.utils.weight_norm(self.linear, dim=0, name="weight")

    def forward(self, x):
        # resolve a negative `dim` against the rank of the input
        out_shape = list(x.size())
        axis = self.dim if self.dim >= 0 else x.dim() + self.dim

        # when the target axis is not the last one: flatten what follows it and swap it to the end
        moved = axis < x.dim() - 1
        inner_shape = None
        if moved:
            x = x.view(*x.size()[:axis + 1], -1).transpose(-1, -2).contiguous()
            inner_shape = list(x.shape)
        x = x.view(-1, x.size(-1))

        # data-dependent init on the first call, then the actual transformation
        if not self.initialized:
            self.init_parameters(x)
        x = self.linear(x)

        # restore the original layout with the new number of features
        out_shape[axis] = self._out_features
        if moved:
            inner_shape[-1] = self._out_features
            x = x.view(inner_shape).transpose(-1, -2)
        return x.view(out_shape)

    @property
    def input_shape(self):
        return (-1, self._in_features)

    @property
    def output_shape(self):
        return (-1, self._out_features)

    def init_parameters(self, x, init_scale=0.05, eps=1e-8):
        """
        Data-dependent initialization of the weight-normalized layer (no-op without weight normalization).

        args:
            x (Tensor): 2D batch (N x in_features) used to compute the statistics
            init_scale (float): standard deviation used to draw `weight_v`
            eps (float): constant added to the variance before taking the square root
        returns:
            the standardized pre-activations of `x`, or None when weight normalization is disabled
        """
        if not self.weightnorm:
            return None

        params = self.linear._parameters
        # initial values
        params['weight_v'].data.normal_(mean=0, std=init_scale)
        params['weight_g'].data.fill_(1.)
        params['bias'].data.fill_(0.)
        init_scale = .01

        # statistics of the output for this batch
        x = self.linear(x)
        m_init = torch.mean(x, 0)
        v_init = torch.var(x, 0)
        scale_init = init_scale / torch.sqrt(v_init + eps)

        # rescale the gain and shift the bias so that the output is centered with std `init_scale`
        gain = params['weight_g']
        gain.data = gain.data * scale_init.view(gain.data.size())
        bias = params['bias']
        bias.data = bias.data - m_init * scale_init

        self.initialized = True + self.initialized
        return scale_init[None, :] * (x - m_init[None, :])
class AsFeatureMap(nn.Module):
    """
    Project a tensor to `target_shape` when the input has fewer dimensions than the target.
    Otherwise the module is the identity and the output shape equals the input shape.
    """

    def __init__(self, input_shape, target_shape, weightnorm=True, **kwargs):
        """
        args:
            input_shape (tuple): shape of the input tensor
            target_shape (tuple): desired output shape, passed to `view` after the projection
            weightnorm (bool): use weight normalization in the projection layer
            kwargs: accepted and ignored
        """
        super().__init__()

        self._input_shp = input_shape

        needs_projection = len(input_shape) < len(target_shape)
        if not needs_projection:
            # nothing to do: the input is passed through unchanged
            self.linear = None
            self._output_shp = input_shape
            return

        # dense layer producing as many features as the target holds per sample
        n_out = np.prod(target_shape[1:])
        self.linear = NormedDense(input_shape, n_out, weightnorm=weightnorm)
        self._output_shp = target_shape

    def forward(self, x):
        projector = self.linear
        if projector is None:
            return x
        return projector(x).view(self.output_shape)

    @property
    def input_shape(self):
        return self._input_shp

    @property
    def output_shape(self):
        return self._output_shp
def get_deep_vae_mnist():
    """
    Get the binary images Deep VAE configuration.

    Every stage is made of `no_layers` blocks [filters, kernel_size, 1] followed by one block
    [filters, kernel_size, stride], and is paired with a `DenseNormal` stochastic layer of `N` units.
    :return: enc, z
    """
    filters = 64
    no_layers = 2

    # one row per stage, in the order the stages are listed: (kernel_size, stride of the last block, N)
    stage_specs = [
        (5, 2, 48),
        (3, 1, 40),
        (3, 1, 32),
        (3, 1, 24),
        (3, 1, 16),
        (3, 2, 8),
    ]

    enc = []
    z = []
    for kernel_size, last_stride, units in stage_specs:
        blocks = [[filters, kernel_size, 1]] * no_layers
        blocks += [[filters, kernel_size, last_stride]]
        enc += [blocks]
        z += [{'N': units, 'block': DenseNormal}]

    return enc, z
[[filters, 3, 1]] 79 | z_4 = {'N': 32, 'kernel': 16, 'block': ConvNormal} 80 | enc += [enc_z4] 81 | z += [z_4] 82 | 83 | enc_z5 = [[filters, 3, 1]] * no_layers 84 | enc_z5 += [[filters, 3, 1]] 85 | z_5 = {'N': 30, 'kernel': 16, 'block': ConvNormal} 86 | enc += [enc_z5] 87 | z += [z_5] 88 | 89 | enc_z6 = [[filters, 3, 1]] * no_layers 90 | enc_z6 += [[filters, 3, 1]] 91 | z_6 = {'N': 28, 'kernel': 16, 'block': ConvNormal} 92 | enc += [enc_z6] 93 | z += [z_6] 94 | 95 | enc_z7 = [[filters, 3, 1]] * no_layers 96 | enc_z7 += [[filters, 3, 1]] 97 | z_7 = {'N': 26, 'kernel': 16, 'block': ConvNormal} 98 | enc += [enc_z7] 99 | z += [z_7] 100 | 101 | enc_z8 = [[filters, 3, 1]] * no_layers 102 | enc_z8 += [[filters, 3, 1]] 103 | z_8 = {'N': 24, 'kernel': 16, 'block': ConvNormal} 104 | enc += [enc_z8] 105 | z += [z_8] 106 | 107 | enc_z9 = [[filters, 3, 1]] * no_layers 108 | enc_z9 += [[filters, 3, 1]] 109 | z_9 = {'N': 22, 'kernel': 16, 'block': ConvNormal} 110 | enc += [enc_z9] 111 | z += [z_9] 112 | 113 | enc_z10 = [[filters, 3, 1]] * no_layers 114 | enc_z10 += [[filters, 3, 1]] 115 | z_10 = {'N': 20, 'kernel': 16, 'block': ConvNormal} 116 | enc += [enc_z10] 117 | z += [z_10] 118 | 119 | enc_z11 = [[filters, 3, 1]] * no_layers 120 | enc_z11 += [[filters, 3, 2]] 121 | z_11 = {'N': 18, 'kernel': 8, 'block': ConvNormal} 122 | enc += [enc_z11] 123 | z += [z_11] 124 | 125 | enc_z12 = [[filters, 3, 1]] * no_layers 126 | enc_z12 += [[filters, 3, 1]] 127 | z_12 = {'N': 16, 'kernel': 8, 'block': ConvNormal} 128 | enc += [enc_z12] 129 | z += [z_12] 130 | 131 | enc_z13 = [[filters, 3, 1]] * no_layers 132 | enc_z13 += [[filters, 3, 1]] 133 | z_13 = {'N': 14, 'kernel': 8, 'block': ConvNormal} 134 | enc += [enc_z13] 135 | z += [z_13] 136 | 137 | enc_z14 = [[filters, 3, 1]] * no_layers 138 | enc_z14 += [[filters, 3, 1]] 139 | z_14 = {'N': 12, 'kernel': 8, 'block': ConvNormal} 140 | enc += [enc_z14] 141 | z += [z_14] 142 | 143 | enc_z15 = [[filters, 3, 1]] * no_layers 144 | enc_z15 += 
class DeepVae(nn.Module):
    """
    A Deep Hierarchical VAE.
    The model is a stack of N stages. Each stage features an inference and a generative path.
    Depending on the choice of the stage, multiple models can be implemented:
    - VAE: https://arxiv.org/abs/1312.6114
    - LVAE: https://arxiv.org/abs/1602.02282
    - BIVA: https://arxiv.org/abs/1902.02102
    """

    def __init__(self,
                 Stage: Any = BivaStage,
                 tensor_shp: Tuple[int] = (-1, 1, 28, 28),
                 padded_shp: Optional[Tuple] = None,
                 stages: List[List[Tuple]] = None,
                 latents: List = None,
                 nonlinearity: str = 'elu',
                 q_dropout: float = 0.,
                 p_dropout: float = 0.,
                 features_out: Optional[int] = None,
                 lambda_init: Optional[Callable] = None,
                 projection: Optional[nn.Module] = None,
                 **kwargs):

        """
        Initialize the Deep VAE model.
        :param Stage: stage constructor (VaeStage, LvaeStage, BivaStage)
        :param tensor_shp: Input tensor shape (batch_size, channels, *dimensions)
        :param padded_shp: pad the dimensions of the input tensor to this shape (one entry per dimension)
        :param stages: a list of list of tuple, each tuple describing a convolutional block (filters, kernel_size, stride)
        :param latents: a list describing the stochastic layers for each stage
        :param nonlinearity: activation function (elu, relu, tanh)
        :param q_dropout: inference dropout value
        :param p_dropout: generative dropout value
        :param features_out: optional number of output features if different from the input
        :param lambda_init: lambda function applied to the input
        :param projection: projection layer with constructor __init__(output_shape)

        :param kwargs: additional arguments passed to each stage
        :raises KeyError: if `nonlinearity` is not one of the supported activations
        """
        super().__init__()
        stages, latents = self.get_default_architecture(stages, latents)

        self.input_tensor_shape = tensor_shp
        self.lambda_init = lambda_init

        # input padding
        if padded_shp is not None:
            pad = []
            # `nn.functional.pad` reads the padding starting from the LAST dimension
            for target, size in reversed(list(zip(padded_shp, tensor_shp[2:]))):
                diff = target - size
                before = diff // 2
                # the remainder goes to the end so that odd differences still reach `padded_shp`
                pad += [before, diff - before]
            self.pad = pad
            self.unpad = [-u for u in self.pad]
            in_shp = [*tensor_shp[:2], *padded_shp]
        else:
            self.pad = None
            self.unpad = None
            in_shp = tensor_shp

        # select activation class (classes, not instances: they are instantiated by the blocks)
        Act = {'elu': nn.ELU, 'relu': nn.ReLU, 'tanh': nn.Tanh}[nonlinearity]

        # initialize the inference path
        stages_ = []
        block_args = {'act': Act, 'q_dropout': q_dropout, 'p_dropout': p_dropout}

        input_shape = {'x': in_shp}
        for i, (conv_data, z_data) in enumerate(zip(stages, latents)):
            top = i == len(stages) - 1
            bottom = i == 0

            stage = Stage(input_shape, conv_data, z_data, top=top, bottom=bottom, **block_args, **kwargs)

            # the inference output of a stage is the input of the next one
            input_shape = stage.q_output_shape
            stages_ += [stage]

        self.stages = nn.ModuleList(stages_)

        # shape of the tensor `d` produced by the generative path of the bottom stage
        out_shp = self.stages[0].p_output_shape['d']
        if projection is None:
            # output convolution
            if features_out is None:
                features_out = self.input_tensor_shape[1]
            conv_obj = nn.Conv2d if len(out_shp) == 4 else nn.Conv1d
            conv_out = conv_obj(out_shp[1], features_out, 1)
            conv_out = PaddedNormedConv(out_shp, conv_out, weightnorm=True)
            self.projection = nn.Sequential(Act(), conv_out)
        else:
            self.projection = projection(out_shp)

    def get_default_architecture(self, stages, latents):
        """
        Replace the missing arguments by the default binarized MNIST architecture.
        :param stages: description of the convolutional blocks or None
        :param latents: description of the stochastic layers or None
        :return: stages, latents
        """
        if stages is None or latents is None:
            default_stages, default_latents = get_deep_vae_mnist()
            if stages is None:
                stages = default_stages
            if latents is None:
                latents = default_latents

        return stages, latents

    def infer(self, x: torch.Tensor, **kwargs: Any) -> List[Dict]:
        """
        Forward pass through the inference network and return the posterior of each stage,
        in the same order as `self.stages`.
        :param x: input tensor
        :param kwargs: additional arguments passed to each stage
        :return: a list that contains the data for each stage
        """
        posteriors = []
        data = {'x': x}
        for stage in self.stages:
            data, posterior = stage.infer(data, **kwargs)
            posteriors += [posterior]

        return posteriors

    def generate(self, posteriors: Optional[List], **kwargs) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the generative model, compute KL and return reconstruction x_, KL and auxiliary data.
        If no posterior is provided, the prior is sampled.
        :param posteriors: a list containing the posterior for each stage
        :param kwargs: additional arguments passed to each stage
        :return: {'x_': reconstruction logits, 'kl': kl for each stage, **auxiliary}
        """
        if posteriors is None:
            posteriors = [None for _ in self.stages]

        output_data = DataCollector()
        x = {}
        # the generative path runs through the stages in reverse order
        for posterior, stage in zip(posteriors[::-1], self.stages[::-1]):
            x, data = stage(x, posterior, **kwargs)
            output_data.extend(data)

        # output convolution
        x = self.projection(x['d'])

        # undo padding (negative padding crops)
        if self.pad is not None:
            x = nn.functional.pad(x, self.unpad)

        # sort data: [z1, z2, ..., z_L]
        output_data = output_data.sort()

        return {'x_': x, **output_data}

    def forward(self, x: torch.Tensor, **kwargs: Any) -> Dict[str, torch.Tensor]:
        r"""
        Forward pass through the inference model, the generative model and compute KL for each stage.
        x_ = p_\theta(x|z), z \sim q_\phi(z|x)
        kl_i = log q_\phi(z_i | h) - log p_\theta(z_i | h)

        :param x: input tensor
        :param kwargs: additional arguments passed to each stage
        :return: {'x_': reconstruction logits, 'kl': kl for each stage, **auxiliary}
        """

        if self.pad is not None:
            x = nn.functional.pad(x, self.pad)

        if self.lambda_init is not None:
            x = self.lambda_init(x)

        posteriors = self.infer(x, **kwargs)

        data = self.generate(posteriors, N=x.size(0), **kwargs)

        return data

    def sample_from_prior(self, N: int, **kwargs: Any) -> Dict[str, torch.Tensor]:
        r"""
        Sample the prior and pass through the generative model.
        x_ = p_\theta(x|z), z \sim p_\theta(z)

        :param N: number of samples (batch size)
        :param kwargs: additional arguments passed to each stage
        :return: {'x_': sample logits}
        """
        return self.generate(None, N=N, **kwargs)
34 | 35 | auxiliary connections: if the number of auxiliary inputs is smaller than the number of layers, 36 | the auxiliary inputs are repeated to match the number of layers. 37 | 38 | :param tensor_shp: input tensor shape as a tuple of integers (B, H, *D) 39 | :param convolutions: describes the sequence of blocks, each of them defined by a tuple (filters, kernel_size, stride) 40 | :param aux_shape: auxiliary input tensor shape as a tuple of integers (B, H, *D) 41 | :param transposed: use transposed convolutions 42 | :param residual: use residual connections 43 | :param Block: Block object constructor (GatedResNet, ResMLP) 44 | """ 45 | super().__init__() 46 | self.input_shape = tensor_shp 47 | self._use_skips = True 48 | layers = [] 49 | 50 | if aux_shape is None: 51 | self._use_skips = False 52 | aux_shape = [] 53 | 54 | for j, dim in enumerate(convolutions): 55 | residual = True if j > 0 else in_residual 56 | aux = aux_shape.pop() if self._use_skips else None 57 | block = Block(tensor_shp, dim, aux_shape=aux, transposed=transposed, residual=residual, 58 | **kwargs) 59 | tensor_shp = block.output_shape 60 | aux_shape = [tensor_shp] + aux_shape 61 | layers += [block] 62 | 63 | self.layers = nn.ModuleList(layers) 64 | self.output_shape = tensor_shp 65 | self.hidden_shapes = aux_shape 66 | 67 | def __len__(self): 68 | return len(self.layers) 69 | 70 | def forward(self, x: Tensor, aux: Optional[List[Tensor]] = None, **kwargs) -> Tuple[Tensor, List[Tensor]]: 71 | """ 72 | :param x: input tensor 73 | :param aux: list of auxiliary inputs 74 | :return: output tensor, activations 75 | """ 76 | if aux is None: 77 | aux = [] 78 | 79 | for layer in self.layers: 80 | a = aux.pop() if self._use_skips else None 81 | x = layer(x, a, **kwargs) 82 | aux = [x] + aux 83 | 84 | return x, aux 85 | 86 | 87 | class BaseStage(nn.Module): 88 | def __init__(self, 89 | input_shape: Dict[str, Tuple[int]], 90 | convolutions: List[Tuple[int]], 91 | stochastic: Tuple, 92 | top: bool = False, 93 
class VaeStage(BaseStage):
    def __init__(self,
                 input_shape: Dict[str, Tuple[int]],
                 convolutions: List[Tuple[int]],
                 stochastic: Tuple,
                 top: bool = False,
                 bottom: bool = False,
                 q_dropout: float = 0,
                 p_dropout: float = 0,
                 Block: Any = GatedResNet,
                 no_skip: bool = False,
                 **kwargs):
        """
        VAE: https://arxiv.org/abs/1312.6114

        A Variational Autoencoder stage made of:
        - a sequence of convolutional blocks for the inference model
        - a stochastic layer
        - a sequence of (transposed) convolutional blocks for the generative model

        :param input_shape: dictionary describing the input tensors of shapes (B, H, *D), keys 'x' and optionally 'aux'
        :param convolutions: list of tuples, each describing a convolutional block (filters, kernel_size, stride)
        :param stochastic: description of the stochastic layer, passed to `StochasticBlock`
        :param top: is top layer
        :param bottom: is bottom layer
        :param q_dropout: inference dropout value
        :param p_dropout: generative dropout value
        :param Block: Block constructor
        :param no_skip: do not use skip connections
        :param kwargs: others arguments passed to the block constructors (both convolutions and stochastic)
        """
        super().__init__(input_shape, convolutions, stochastic, top=top, bottom=bottom, q_dropout=q_dropout,
                         p_dropout=p_dropout, Block=Block, no_skip=no_skip)

        n_blocks = len(convolutions)
        x_shape = input_shape.get('x')
        # the auxiliary input is dropped entirely when skip connections are muted
        skip_shape = None if no_skip else input_shape.get('aux', None)

        ### INFERENCE MODEL

        # every inference block receives the same auxiliary input
        q_skips = None if skip_shape is None else [skip_shape] * n_blocks
        self.q_convs = DeterministicBlocks(x_shape, convolutions, aux_shape=q_skips, transposed=False,
                                           in_residual=not bottom, Block=Block, dropout=q_dropout, **kwargs)

        # shape of the deterministic output
        h_shape = self.q_convs.output_shape

        # stochastic layer and projection of its samples for the next stage
        self.stochastic = StochasticBlock(stochastic, h_shape, top=top, **kwargs)
        self.q_proj = AsFeatureMap(self.stochastic.output_shape, self.stochastic.input_shape)

        self._q_output_shape = {'x': self.q_proj.output_shape, 'aux': h_shape}

        ### GENERATIVE MODEL

        # project z sample
        self.p_proj = AsFeatureMap(self.stochastic.output_shape, self.stochastic.input_shape)

        # The skip connections of the generative blocks are assumed to have the shape `h_shape`, which does
        # not hold for every configuration of the generative model. A more general architecture would need a
        # top-down __init__() receiving the skip shapes of the generative blocks of the stage above.
        if top or no_skip:
            p_skips = None
        else:
            p_skips = [h_shape] * n_blocks
        self.p_convs = DeterministicBlocks(self.p_proj.output_shape, convolutions[::-1],
                                           aux_shape=p_skips, transposed=True,
                                           in_residual=False, Block=Block, dropout=p_dropout, **kwargs)

        self._p_output_shape = {'d': self.p_convs.output_shape, 'aux': self.p_convs.hidden_shapes}

    @property
    def q_output_shape(self) -> Dict[str, Tuple[int]]:
        """shapes of the tensors returned by the inference path (keys 'x' and 'aux')"""
        return self._q_output_shape

    @property
    def p_output_shape(self) -> Tuple[int]:
        """shapes of the tensors returned by the generative path (keys 'd' and 'aux')"""
        return self._p_output_shape

    def infer(self, data: Dict[str, Tensor], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Run the inference blocks, sample the posterior and project the sample.

        :param data: input data with key 'x' and optionally 'aux'
        :param kwargs: additional parameters passed to the stochastic layer
        :return: (output data {'x', 'aux'}, variational data)
        """
        x = data.get('x')
        skip = data.get('aux', None)
        if self._no_skip:
            skip = None

        # the same auxiliary tensor is given to every inference block
        skips = None if skip is None else [skip] * len(self.q_convs)
        h, _ = self.q_convs(x, skips)

        sample, q_data = self.stochastic(h, inference=True, **kwargs)
        projected = self.q_proj(sample)

        return {'x': projected, 'aux': h}, q_data

    def forward(self, data: dict, posterior: Optional[dict], **kwargs) -> Tuple[
        Dict, Dict[str, List]]:
        """
        Run the generative path of the stage; the loss terms are computed when a posterior is given.

        :param data: data from the above stage forward pass
        :param posterior: dictionary representing the posterior from same stage inference pass, or None
        :return: (dict('d' : d, 'aux : [aux]), dict('kl': [kl], **auxiliary) )
        """
        d = data.get('d', None)
        skips_in = data.get('aux', None)
        if self._no_skip:
            skips_in = None

        # p(z | d): a sample is only requested when there is no posterior to take z from
        from_prior = posterior is None
        z_p, p_data = self.stochastic(d, inference=False, sample=from_prior, **kwargs)

        if from_prior:
            loss_data = {}
            z = z_p
        else:
            # compare posterior and prior, and keep the posterior sample
            loss_data = self.stochastic.loss(posterior, p_data, **kwargs)
            z = posterior.get('z')

        # project z and pass it through the generative blocks
        d, skips_out = self.p_convs(self.p_proj(z), aux=skips_in)

        return {'d': d, 'aux': skips_out}, loss_data
301 | LVAE: https://arxiv.org/abs/1602.02282 302 | 303 | Define a Ladder Variational Autoencoder stage containing: 304 | - a sequence of convolutional blocks for the inference model 305 | - a sequence of convolutional blocks for the generative model 306 | - a stochastic layer 307 | 308 | :param input_shape: dictionary describing the input tensors of shapes (B, H, *D) 309 | :param convolution: list of tuple describing a convolutional block (filters, kernel_size, stride) 310 | :param stochastic: integer or tuple describing the stochastic layer: units or (units, kernel_size, discrete, K) 311 | :param top: is top layer 312 | :param bottom: is bottom layer 313 | :param q_dropout: inference dropout value 314 | :param p_dropout: generative dropout value 315 | :param kwargs: others arguments passed to the block constructors (both convolutions and stochastic) 316 | """ 317 | super().__init__(input_shape, convolutions, stochastic, top=top, bottom=bottom, p_dropout=p_dropout, 318 | q_dropout=q_dropout, Block=Block, **kwargs) 319 | self.q_proj = None 320 | # get the tensor shape of the output of the deterministic path 321 | top_shape = self._q_output_shape.get('aux') 322 | # modify the output of the inference path to be only deterministic 323 | self._q_output_shape['x'] = top_shape 324 | 325 | topdown = top_shape if not top else None 326 | conv = convolutions[-1] 327 | if isinstance(conv, list): 328 | conv = [conv[0], conv[1], 1, conv[-1]] 329 | self.merge = Block(top_shape, conv, aux_shape=topdown, transposed=False, in_residual=True, dropout=p_dropout, 330 | **kwargs) 331 | 332 | def infer(self, data: Dict[str, Tensor], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: 333 | """ 334 | Perform a forward pass through the inference layers and sample the posterior. 
335 | 336 | :param data: input data 337 | :param kwargs: additional parameters passed to the stochastic layer 338 | :return: (output data, variational data) 339 | """ 340 | x = data.get('x') 341 | aux = data.get('aux', None) 342 | if self._no_skip: 343 | aux = None 344 | 345 | aux = [aux for _ in range(len(self.q_convs))] if aux is not None else None 346 | x, _ = self.q_convs(x, aux) 347 | 348 | return {'x': x, 'aux': x}, {'h': x} 349 | 350 | def forward(self, data: dict, posterior: Optional[dict], debugging: bool = False, **kwargs) -> Tuple[ 351 | Dict, Dict[str, List[Tensor]]]: 352 | """ 353 | Perform a forward pass through the generative model and compute KL if posterior data is available 354 | 355 | :param data: data from the above stage forward pass 356 | :param posterior: dictionary representing the posterior 357 | :return: (dict('d' : d, 'aux : [aux]), dict('kl': [kl], **auxiliary) ) 358 | """ 359 | d = data.get('d', None) 360 | 361 | # sample p(z | d) 362 | z_p, p_data = self.stochastic(d, inference=False, sample=posterior is None, **kwargs) 363 | 364 | # sample q(z | h) and compute KL(q | p) 365 | if posterior is not None: 366 | # compute the top-down logits of q(z_i | x, z_{>i}) 367 | h = posterior.get('h') 368 | h = self.merge(h, aux=d) 369 | 370 | # z ~ q(z | h_bu, d_td) 371 | z_q, q_data = self.stochastic(h, inference=True, **kwargs) 372 | 373 | loss_data = self.stochastic.loss(q_data, p_data, **kwargs) 374 | z = z_q 375 | else: 376 | loss_data = {} 377 | z = z_p 378 | 379 | # project z 380 | z = self.p_proj(z) 381 | 382 | # pass through convolutions 383 | aux = data.get('aux', None) 384 | if self._no_skip: 385 | aux = None 386 | 387 | d, skips = self.p_convs(z, aux) 388 | 389 | output_data = {'d': d, 'aux': skips} 390 | return output_data, loss_data 391 | 392 | 393 | class BivaIntermediateStage(BaseStage): 394 | def __init__(self, 395 | input_shape: Dict[str, Tuple[int]], 396 | convolutions: List[Tuple[int]], 397 | stochastic: Union[Dict, 
Tuple[Dict]], 398 | top: bool = False, 399 | bottom: bool = False, 400 | q_dropout: float = 0, 401 | p_dropout: float = 0, 402 | no_skip: bool = False, 403 | conditional_bu: bool = False, 404 | Block: Any = GatedResNet, 405 | merge_kernel: int = 3, 406 | **kwargs): 407 | """ 408 | BIVA: https://arxiv.org/abs/1902.02102 409 | 410 | Define a Bidirectional Variational Autoencoder stage containing: 411 | - a sequence of convolutional blocks for the bottom-up inference model (BU) 412 | - a sequence of convolutional blocks for the top-down inference model (TD) 413 | - a sequence of convolutional blocks for the generative model 414 | - two stochastic layers (BU and TD) 415 | 416 | :param input_shape: dictionary describing the input tensor shape (B, H, *D) 417 | :param convolution: list of tuple describing a convolutional block (filters, kernel_size, stride) 418 | :param stochastic: dictionary describing the stochastic layer: units or (units, kernel_size, discrete, K) 419 | :param bottom: is bottom layer 420 | :param top: is top layer 421 | :param q_dropout: inference dropout value 422 | :param p_dropout: generative dropout value 423 | :param no_skip: do not use skip connections 424 | :param conditional_bu: condition BU prior on p(z_TD) 425 | :param aux_shape: auxiliary input tensor shape as a tuple of integers (B, H, *D) 426 | :param kwargs: others arguments passed to the block constructors (both convolutions and stochastic) 427 | """ 428 | super().__init__(input_shape, convolutions, stochastic, top=top, bottom=bottom, q_dropout=q_dropout, 429 | p_dropout=p_dropout, Block=Block, no_skip=no_skip) 430 | 431 | self._conditional_bu = conditional_bu 432 | self._merge_kernel = merge_kernel 433 | 434 | if 'x' in input_shape.keys(): 435 | bu_shp = td_shp = input_shape.get('x') 436 | aux_shape = None 437 | else: 438 | bu_shp = input_shape.get('x_bu') 439 | td_shp = input_shape.get('x_td') 440 | aux_shape = input_shape.get('aux') 441 | 442 | if isinstance(stochastic.get('block'), 
tuple): 443 | bu_block, td_block = stochastic.get('block') 444 | bu_stochastic = copy(stochastic) 445 | td_stochastic = copy(stochastic) 446 | bu_stochastic['block'] = bu_block 447 | td_stochastic['block'] = td_block 448 | else: 449 | bu_stochastic = td_stochastic = stochastic 450 | 451 | # mute skip connections 452 | if no_skip: 453 | aux_shape = None 454 | 455 | # define inference convolutional blocks 456 | in_residual = not bottom 457 | q_bu_aux = [aux_shape for _ in convolutions] if aux_shape is not None else None 458 | self.q_bu_convs = DeterministicBlocks(bu_shp, convolutions, aux_shape=q_bu_aux, transposed=False, 459 | in_residual=in_residual, dropout=q_dropout, Block=Block, **kwargs) 460 | 461 | q_td_aux = [self.q_bu_convs.output_shape for _ in convolutions] 462 | self.q_td_convs = DeterministicBlocks(td_shp, convolutions, aux_shape=q_td_aux, transposed=False, 463 | in_residual=in_residual, dropout=q_dropout, Block=Block, **kwargs) 464 | 465 | # shape of the output of the inference path and input tensor from the generative path 466 | top_tensor_shp = self.q_td_convs.output_shape 467 | aux_shape = shp_cat([top_tensor_shp, top_tensor_shp], 1) 468 | 469 | # define the BU stochastic layer 470 | bu_top = False if conditional_bu else top 471 | self.bu_stochastic = StochasticBlock(bu_stochastic, top_tensor_shp, top=bu_top, **kwargs) 472 | self.bu_proj = AsFeatureMap(self.bu_stochastic.output_shape, self.bu_stochastic.input_shape, **kwargs) 473 | 474 | # define the TD stochastic layer 475 | self.td_stochastic = StochasticBlock(td_stochastic, top_tensor_shp, top=top, **kwargs) 476 | 477 | self._q_output_shape = {'x_bu': self.bu_proj.output_shape, 478 | 'x_td': top_tensor_shp, 479 | 'aux': aux_shape} 480 | 481 | ### GENERATIVE MODEL 482 | 483 | # TD merge layer 484 | h_shape = self._q_output_shape.get('x_td', None) if not self._top else None 485 | conv = self._convolutions[::-1][-1] 486 | if isinstance(conv, list) or isinstance(conv, tuple): 487 | conv = [conv[0], 
merge_kernel, 1, 488 | conv[-1]] # in the original implementation, this depends on the parameters of the above layers 489 | self.merge = Block(h_shape, conv, aux_shape=h_shape, transposed=False, in_residual=True, 490 | dropout=p_dropout, 491 | **kwargs) 492 | 493 | # alternative: define the condition p(z_bu | z_td, ...) 494 | if conditional_bu: 495 | self.bu_condition = Block(self.bu_stochastic.output_shape, conv, aux_shape=h_shape, transposed=False, 496 | in_residual=False, 497 | dropout=p_dropout, **kwargs) 498 | else: 499 | self.bu_condition = None 500 | 501 | # merge latent variables 502 | z_shp = shp_cat([self.bu_stochastic.output_shape, self.td_stochastic.output_shape], 1) 503 | self.z_proj = AsFeatureMap(z_shp, self.bu_stochastic.input_shape) 504 | 505 | # define the generative convolutional blocks with the skip connections 506 | # here we assume the skip connections to be of the same shape as `top_tensor_shape` : this does not work with 507 | # with every configuration of the generative model. Making the arhitecture more general requires to have 508 | # a top-down __init__() method such as to take the shapes of the above generative block skip connections as input. 
def infer(self, data: Dict[str, Tensor], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Run the bottom-up (BU) and top-down (TD) inference paths and sample the BU posterior.

    :param data: input data, either a single tensor under `x` or two tensors under `x_bu` / `x_td`,
                 plus an optional skip connection under `aux`
    :param kwargs: additional parameters passed to the stochastic layer
    :return: (output data, variational data)
    """
    # a single input feeds both paths, otherwise each path reads its own tensor
    if 'x' in data:
        bu_input = td_input = data.get('x')
    else:
        bu_input, td_input = data.get('x_bu'), data.get('x_td')

    # skip connections can be muted stage-wide
    skip = None if self._no_skip else data.get('aux', None)

    # BU path: every block receives the same skip connection
    bu_skips = None if skip is None else [skip] * len(self.q_bu_convs)
    h_bu, _ = self.q_bu_convs(bu_input, aux=bu_skips)

    # z_bu ~ q(z_bu | x), then map the sample through `bu_proj`
    z_bu, bu_posterior = self.bu_stochastic(h_bu, inference=True, **kwargs)
    z_bu_projected = self.bu_proj(z_bu)

    # TD path: every block is conditioned on the output of the BU path
    td_skips = [h_bu] * len(self.q_td_convs)
    h_td, _ = self.q_td_convs(td_input, aux=td_skips)

    # skip connection handed to the next stage: BU and TD features concatenated along dim 1
    next_skip = torch.cat([h_bu, h_td], 1)

    outputs = {'x_bu': z_bu_projected, 'x_td': h_td, 'aux': next_skip}
    variational = {
        'z_bu': z_bu,
        'bu': bu_posterior,
        'td': {'z': z_bu, 'h': h_td},  # note h = d_q(x)
    }
    return outputs, variational
def forward(self, data: Dict, posterior: Optional[dict], debugging: bool = False, **kwargs) -> Tuple[ 563 | Dict, Dict[str, List[Tensor]]]: 564 | """ 565 | Perform a forward pass through the generative model and compute KL if posterior data is available 566 | 567 | :param d: previous hidden state 568 | :param posterior: dictionary representing the posterior 569 | :return: (hidden state, dict('kl': [kl], **auxiliary)) 570 | """ 571 | 572 | d = data.get('d', None) 573 | 574 | if posterior is not None: 575 | # sample posterior and compute KL using prior 576 | bu_q_data = posterior.get('bu') 577 | td_q_data = posterior.get('td') 578 | z_bu_q = posterior.get('z_bu') 579 | 580 | # top-down: compute the posterior using the bottom-up hidden state and top-down hidden state 581 | # p(z_td | d_top) 582 | _, td_p_data = self.td_stochastic(d, inference=False, sample=False, **kwargs) 583 | 584 | # merge d_top with h = d_q(x) 585 | h = td_q_data.get('h') 586 | h = self.merge(h, aux=d) 587 | 588 | # z_td ~ q(z_td | h_bu_td) 589 | z_td_q, td_q_data = self.td_stochastic(h, inference=True, sample=True, **kwargs) 590 | 591 | # compute log q_bu(z_i | x) - log p_bu(z_i) (+ additional data) 592 | td_loss_data = self.td_stochastic.loss(td_q_data, td_p_data, **kwargs) 593 | 594 | # conditional BU prior 595 | if self.bu_condition is not None: 596 | d_ = self.bu_condition(z_td_q, aux=d) 597 | else: 598 | d_ = d 599 | 600 | # bottom-up: retrieve data from the inference path 601 | # z_bu ~ p(d_top) 602 | _, bu_p_data = self.bu_stochastic(d_, inference=False, sample=False, **kwargs) 603 | 604 | # compute log q_td(z_i | x, z_{>i}) - log p_td(z_i) (+ additional data) 605 | bu_loss_data = self.bu_stochastic.loss(bu_q_data, bu_p_data, **kwargs) 606 | 607 | # merge samples 608 | z = torch.cat([z_td_q, z_bu_q], 1) 609 | 610 | else: 611 | # sample priors 612 | # top-down 613 | z_td_p, td_p_data = self.td_stochastic(d, inference=False, sample=True, **kwargs) # prior 614 | 615 | # conditional BU prior 
class BivaTopStage_simpler(VaeStage):
    """
    This is the BivaTopStage without the additional BU-TD merge layer.

    The bottom-up (`x_bu`) and top-down (`x_td`) inputs are concatenated along dimension 1 and handled
    as the single input `x` of a regular `VaeStage`.
    """

    def __init__(self, input_shape: Dict[str, Tuple[int]], *args, **kwargs):
        """
        :param input_shape: dictionary holding the shapes of the `x_bu` and `x_td` inputs
        :param args: positional arguments forwarded to `VaeStage.__init__`
        :param kwargs: keyword arguments forwarded to `VaeStage.__init__`
        """
        bu_shp = input_shape.get('x_bu')
        td_shp = input_shape.get('x_td')

        # the parent stage is built for the concatenated tensor only
        tensor_shp = shp_cat([bu_shp, td_shp], 1)
        concat_shape = {'x': tensor_shp}

        super().__init__(concat_shape, *args, **kwargs)

    def infer(self, data: Dict[str, Tensor], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Concatenate the BU and TD inputs and run the inference pass of the parent stage.

        The input dictionary is left untouched: the previous implementation popped `x_bu` / `x_td` from it
        and inserted `x`, which altered the caller's data.

        :param data: input data, must contain the keys `x_bu` and `x_td`
        :param kwargs: additional parameters passed to the parent `infer()`
        :return: (output data, variational data) as returned by the parent `infer()`
        :raises KeyError: if `x_bu` or `x_td` is missing
        """
        x_bu = data['x_bu']
        x_td = data['x_td']

        # work on a shallow copy: same content as the former in-place edit, without side effects
        merged = {key: value for key, value in data.items() if key not in ('x_bu', 'x_td')}
        merged['x'] = torch.cat([x_bu, x_td], 1)

        return super().infer(merged, **kwargs)
stage containing: 685 | - a sequence of convolutional blocks for the bottom-up inference model (BU) 686 | - a sequence of convolutional blocks for the top-down inference model (TD) 687 | - a convolutional block to merge BU and TD 688 | - a sequence of convolutional blocks for the generative model 689 | - a stochastic layer (z_L) 690 | 691 | :param input_shape: dictionary describing the input tensor shape (B, H, *D) 692 | :param convolution: list of tuple describing a convolutional block (filters, kernel_size, stride) 693 | :param stochastic: dictionary describing the stochastic layer: units or (units, kernel_size, discrete, K) 694 | :param bottom: is bottom layer 695 | :param top: is top layer 696 | :param q_dropout: inference dropout value 697 | :param p_dropout: generative dropout value 698 | :param no_skip: do not use skip connections 699 | :param aux_shape: auxiliary input tensor shape as a tuple of integers (B, H, *D) 700 | :param kwargs: others arguments passed to the block constructors (both convolutions and stochastic) 701 | """ 702 | super().__init__(input_shape, convolutions, stochastic, top=top, bottom=bottom, q_dropout=q_dropout, 703 | p_dropout=p_dropout, Block=Block, no_skip=no_skip) 704 | top = True 705 | 706 | if 'x' in input_shape.keys(): 707 | bu_shp = td_shp = input_shape.get('x') 708 | aux_shape = None 709 | else: 710 | bu_shp = input_shape.get('x_bu') 711 | td_shp = input_shape.get('x_td') 712 | aux_shape = input_shape.get('aux') 713 | 714 | # mute skip connections 715 | if no_skip: 716 | aux_shape = None 717 | 718 | # define inference BU and TD paths 719 | in_residual = not bottom 720 | q_bu_aux = [aux_shape for _ in convolutions] if aux_shape is not None else None 721 | self.q_bu_convs = DeterministicBlocks(bu_shp, convolutions, aux_shape=q_bu_aux, transposed=False, 722 | in_residual=in_residual, dropout=q_dropout, Block=Block, **kwargs) 723 | 724 | q_td_aux = [self.q_bu_convs.output_shape for _ in convolutions] 725 | self.q_td_convs = 
DeterministicBlocks(td_shp, convolutions, aux_shape=q_td_aux, transposed=False, 726 | in_residual=in_residual, dropout=q_dropout, Block=Block, **kwargs) 727 | 728 | # merge BU and TD paths 729 | conv = convolutions[-1] 730 | self.q_top = Block(shp_cat([self.q_bu_convs.output_shape, self.q_td_convs.output_shape], 1), 731 | [conv[0], conv[1], 1, conv[-1]], dropout=q_dropout, 732 | residual=True, **kwargs) 733 | top_tensor_shp = self.q_top.output_shape 734 | 735 | # stochastic layer 736 | self.stochastic = StochasticBlock(stochastic, top_tensor_shp, top=top, **kwargs) 737 | 738 | self._q_output_shape = {} # no output shape (top layer) 739 | 740 | ### GENERATIVE MODEL 741 | 742 | # map sample back to a feature map 743 | self.z_proj = AsFeatureMap(self.stochastic.output_shape, self.stochastic.input_shape) 744 | 745 | # define the generative convolutional blocks with the skip connections 746 | p_skips = None 747 | self.p_convs = DeterministicBlocks(self.z_proj.output_shape, self._convolutions[::-1], 748 | aux_shape=p_skips, transposed=True, 749 | in_residual=False, Block=Block, dropout=p_dropout, **kwargs) 750 | 751 | self._p_output_shape = {'d': self.p_convs.output_shape, 'aux': self.p_convs.hidden_shapes} 752 | 753 | def infer(self, data: Dict[str, Tensor], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: 754 | """ 755 | Perform a forward pass through the inference layers and sample the posterior. 
def forward(self, data: Dict, posterior: Optional[dict], **kwargs) -> Tuple[
    Dict, Dict[str, List]]:
    """
    Generative pass of the top stage: pick the latent sample, project it and run the generative blocks.

    :param data: data from the above stage forward pass, read from the keys `d` and (optionally) `aux`
    :param posterior: dictionary representing the posterior, or None to sample from the prior
    :param kwargs: additional parameters passed to the stochastic layer
    :return: (dict('d': d, 'aux': skips), dict of losses, empty when no posterior is given)
    """
    use_prior_sample = posterior is None

    # p(z | d): a sample is only drawn when there is no posterior to take z from
    z_prior, prior_data = self.stochastic(data.get('d', None), inference=False, sample=use_prior_sample,
                                          **kwargs)

    if use_prior_sample:
        losses = {}
        latent = z_prior
    else:
        # compute KL(q | p) and reuse the posterior sample
        losses = self.stochastic.loss(posterior, prior_data, **kwargs)
        latent = posterior.get('z')

    # project z
    features = self.z_proj(latent)

    # pass through convolutions, skip connections may be muted
    skip = None if self._no_skip else data.get('aux', None)
    hidden, skips = self.p_convs(features, aux=skip)

    return {'d': hidden, 'aux': skips}, losses
class StochasticLayer(nn.Module):
    """
    Abstract interface of a VAE stochastic layer: subclasses implement `forward` and `loss`.
    """

    def __init__(self, data: Dict, tensor_shp: Tuple[int], **kwargs: Any):
        super().__init__()
        # shape of the tensor given to the layer; the output shape is left to the subclasses
        self._input_shape = tensor_shp
        self._output_shape = None

    @property
    def input_shape(self):
        return self._input_shape

    @property
    def output_shape(self):
        return self._output_shape

    def forward(self, x: Optional[Tensor], inference: bool, sample: bool = True, N: Optional[int] = None, **kwargs) -> \
            Tuple[Tensor, Dict[str, Any]]:
        """
        Build the distribution parametrized by `x` and draw a sample when `sample` is True.
        When no hidden state is provided, the prior is used.

        :param x: hidden state used to compute the logits (Optional : None means using the prior)
        :param inference: inference mode
        :param sample: sample layer
        :param N: number of samples (when sampling from prior)
        :param kwargs: additional args passed to the stochastic layer
        :return: (projected sample, data)
        """
        raise NotImplementedError

    def loss(self, q_data: Dict[str, Any], p_data: Dict[str, Any], **kwargs: Any) -> Dict[str, List[Any]]:
        """
        Compute the KL divergence, together with other auxiliary losses if required.

        :param q_data: data received from the posterior forward pass
        :param p_data: data received from the prior forward pass
        :param kwargs: other parameters passed to the kl function
        :return: dictionary of losses {'kl': [values], 'auxiliary' : [aux_values], ...}
        """
        raise NotImplementedError
57 | """ 58 | 59 | def __init__(self, data: Dict, tensor_shp: Tuple[int], top: bool = False, act: nn.Module = nn.ELU, 60 | weightnorm: bool = True, log_var_act: Optional[Callable] = nn.Softplus, **kwargs): 61 | super().__init__(data, tensor_shp) 62 | 63 | self._input_shape = tensor_shp 64 | 65 | self.eps = 1e-8 66 | self.nz = data.get('N') 67 | self.tensor_shp = tensor_shp 68 | self.dim = 2 69 | self.act = act() 70 | self.log_var_act = log_var_act() if log_var_act is not None else None 71 | 72 | # stochastic layer and prior 73 | if top: 74 | prior = torch.zeros((2 * self.nz)) 75 | self.register_buffer('prior', prior) 76 | 77 | # computes logits 78 | nz_in = 2 * self.nz 79 | self.qx2z = NormedDense(tensor_shp, nz_in, weightnorm=weightnorm) 80 | if not top: 81 | self.px2z = NormedDense(tensor_shp, nz_in, weightnorm=weightnorm) 82 | 83 | self._output_shape = (-1, self.nz) 84 | 85 | @property 86 | def output_shape(self): 87 | return self._output_shape 88 | 89 | @property 90 | def input_shape(self): 91 | return self._input_shape 92 | 93 | def compute_logits(self, x: Tensor, inference: bool) -> Tuple[Tensor, Tensor]: 94 | """ 95 | Compute the logits of the distribution. 
def forward(self, x: Optional[Tensor], inference: bool, sample: bool = True, N: Optional[int] = None, **kwargs) -> \
        Tuple[Tensor, Dict[str, Any]]:
    """
    Build the Normal distribution of the layer and optionally draw a reparametrized sample.

    :param x: hidden state given to `compute_logits`; None means that the registered `prior` is used
    :param inference: inference mode, forwarded to `compute_logits`
    :param sample: draw a sample from the distribution, otherwise the returned sample is None
    :param N: leading dimension used to expand the prior (only read when `x` is None)
    :param kwargs: unused by this method
    :return: (sample or None, {'z': sample or None, 'dist': Normal distribution})
    """
    if x is not None:
        mean, log_variance = self.compute_logits(x, inference)
    else:
        # replicate the prior along a new leading dimension of size N, then split it in two halves
        batched_prior = self.prior.expand(N, *self.prior.shape)
        mean, log_variance = batched_prior.chunk(2, dim=1)

    # standard deviation from the log-variance
    scale = log_variance.mul(0.5).exp()
    distribution = Normal(mean, scale)

    z = None
    if sample:
        z = distribution.rsample()

    return z, {'z': z, 'dist': distribution}
143 | """ 144 | 145 | def __init__(self, data: Dict, tensor_shp: Tuple[int], top: bool = False, act: nn.Module = nn.ELU, 146 | learn_prior: bool = False, weightnorm: bool = True, log_var_act: Optional[Callable] = nn.Softplus, 147 | **kwargs): 148 | super().__init__(data, tensor_shp) 149 | 150 | self.eps = 1e-8 151 | nhid = tensor_shp[1] 152 | self.nz = data.get('N') 153 | kernel_size = data.get('kernel') 154 | self.tensor_shp = tensor_shp 155 | self.input_shp = tensor_shp 156 | self.act = act() 157 | self.log_var_act = log_var_act() if log_var_act is not None else None 158 | 159 | # prior 160 | if top: 161 | prior = torch.zeros((2 * self.nz, *tensor_shp[2:])) 162 | 163 | if learn_prior: 164 | self.prior = nn.Parameter(prior) 165 | else: 166 | self.register_buffer('prior', prior) 167 | 168 | # computes logits 169 | nz_in = 2 * self.nz 170 | self.qx2z = PaddedNormedConv(tensor_shp, nn.Conv2d(nhid, nz_in, kernel_size), weightnorm=weightnorm) 171 | if not top: 172 | self.px2z = PaddedNormedConv(tensor_shp, nn.Conv2d(nhid, nz_in, kernel_size), weightnorm=weightnorm) 173 | 174 | # compute output shape 175 | nz_out = self.nz 176 | out_shp = (-1, nz_out, *tensor_shp[2:]) 177 | self._output_shape = out_shp 178 | self._input_shape = tensor_shp 179 | 180 | @property 181 | def output_shape(self): 182 | return self._output_shape 183 | 184 | @property 185 | def input_shape(self): 186 | return self._input_shape 187 | 188 | def compute_logits(self, x: Tensor, inference: bool) -> Tuple[Tensor, Tensor]: 189 | """ 190 | Compute the logits of the distribution. 
class DataCollector(defaultdict):
    """
    A `defaultdict(list)` gathering lists of values (e.g. losses) under string keys.
    """

    def __init__(self, default_factory=list, *args, **kwargs):
        """
        :param default_factory: factory of the missing values, `list` unless stated otherwise
        :param args: positional arguments forwarded to `defaultdict`
        :param kwargs: keyword arguments forwarded to `defaultdict`

        `DataCollector()` behaves as before. The extra arguments exist because `defaultdict` rebuilds
        itself by calling `type(self)(default_factory, ...)` when copied, deep-copied or unpickled:
        a constructor taking no argument made these operations fail with a `TypeError`.
        """
        super().__init__(default_factory, *args, **kwargs)

    def extend(self, data: Dict[str, List[Optional[torch.Tensor]]]) -> None:
        """Append the items of each list of `data` to the list stored under the same key."""
        for key, values in data.items():
            self[key] += values

    def sort(self) -> Dict[str, List[Optional[torch.Tensor]]]:
        """Reverse each stored list, drop its None entries and return the collector itself."""
        # iterate over a snapshot of the keys since values are re-assigned in the loop
        for key in list(self):
            self[key] = [item for item in reversed(self[key]) if item is not None]

        return self
def log_prob_from_logits(x):
    """Numerically stable log-softmax over the last axis (TF ordering).

    The maximum along the last axis is subtracted before exponentiating so
    that large logits do not overflow.
    """
    last_axis = x.dim() - 1
    peak = x.max(dim=last_axis, keepdim=True)[0]
    shifted = x - peak
    log_normalizer = shifted.exp().sum(dim=last_axis, keepdim=True).log()
    return shifted - log_normalizer
89 | cdf_min = torch.sigmoid(min_in) 90 | # log probability for edge case of 0 (before scaling) 91 | log_cdf_plus = plus_in - F.softplus(plus_in) 92 | # log probability for edge case of 255 (before scaling) 93 | log_one_minus_cdf_min = -F.softplus(min_in) 94 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 95 | mid_in = inv_stdv * centered_x 96 | # log probability in the center of the bin, to be used in extreme cases 97 | # (not actually used in our code) 98 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) 99 | 100 | # now select the right output: left edge case, right edge case, normal 101 | # case, extremely low prob case (doesn't actually happen for us) 102 | 103 | # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() 104 | # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) 105 | 106 | # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) 107 | # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs 108 | # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue 109 | # if the probability on a sub-pixel is below 1e-5, we use an approximation 110 | # based on the assumption that the log-density is constant in the bin of 111 | # the observed sub-pixel value 112 | 113 | inner_inner_cond = (cdf_delta > 1e-5).float() 114 | inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * ( 115 | log_pdf_mid - np.log(127.5)) 116 | inner_cond = (x > 0.999).float() 117 | inner_out = inner_cond * log_one_minus_cdf_min + (1. 
def discretized_mix_logistic_loss_1d(x, l):
    """Negative log-likelihood of single-channel data under a mixture of discretized logistics.

    Assumes the data has been rescaled to the [-1, 1] interval.

    :param x: observations in PyTorch ordering (batch, 1, height, width)
    :param l: mixture parameters in PyTorch ordering (batch, 3 * nr_mix, height, width);
              along the channel axis: nr_mix mixture logits, then nr_mix means,
              then nr_mix log-scales
    :return: scalar tensor, the negative log-likelihood summed over all batch items and pixels
    """
    # switch to channels-last ordering
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)
    xs = [int(y) for y in x.size()]
    ls = [int(y) for y in l.size()]

    # here and below: unpacking the params of the mixture of logistics
    nr_mix = ls[-1] // 3
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2])  # 2 for mean, scale
    means = l[:, :, :, :, :nr_mix]
    log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)

    # broadcast the observations against the mixture components; the helper tensor is
    # created on the device of `x` so that the function also runs on CPU
    x = x.contiguous()
    x = x.unsqueeze(-1) + torch.zeros(xs + [nr_mix], device=x.device)

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = torch.sigmoid(min_in)
    # log probability for edge case of 0 (before scaling)
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log probability for edge case of 255 (before scaling)
    log_one_minus_cdf_min = -F.softplus(min_in)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    # log probability in the center of the bin, used when the bin probability is below 1e-5
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # select the right output: left edge case, right edge case, normal case,
    # extremely low probability case
    inner_inner_cond = (cdf_delta > 1e-5).float()
    inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (
            log_pdf_mid - np.log(127.5))
    inner_cond = (x > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond = (x < -0.999).float()
    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out

    # add the log mixture weights and marginalize over the mixture components
    log_probs = torch.sum(log_probs, dim=3) + F.log_softmax(logit_probs, dim=-1)

    return -torch.sum(torch.logsumexp(log_probs, dim=-1))
def sample_from_discretized_mix_logistic(l, nr_mix):
    """Sample 3-channel images from a mixture of discretized logistics.

    :param l: mixture parameters in PyTorch ordering (batch, 10 * nr_mix, height, width);
              along the channel axis: nr_mix mixture logits followed, for each of the
              3 sub-pixels, by nr_mix means, nr_mix log-scales and nr_mix coefficients
    :param nr_mix: number of mixture components
    :return: samples in PyTorch ordering (batch, 3, height, width), clipped to [-1, 1]
    """
    # switch to channels-last ordering
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    xs = ls[:-1] + [3]

    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])

    # sample the mixture indicator from the softmax using the Gumbel-max trick;
    # the noise buffer follows the device and dtype of the logits
    noise = torch.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5)
    perturbed = logit_probs.detach() - torch.log(- torch.log(noise))
    _, argmax = perturbed.max(dim=3)

    # one-hot encode the selected component along the last axis
    one_hot = torch.zeros_like(logit_probs).scatter_(3, argmax.unsqueeze(-1), 1.)
    sel = one_hot.view(xs[:-1] + [1, nr_mix])

    # select logistic parameters
    means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(torch.sum(
        l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
    coeffs = torch.sum(torch.tanh(
        l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, dim=4)

    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = torch.empty_like(means).uniform_(1e-5, 1. - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    # each sub-pixel is conditioned linearly on the previously sampled ones
    x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.)
    x1 = torch.clamp(torch.clamp(
        x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, min=-1.), max=1.)
    x2 = torch.clamp(torch.clamp(
        x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, min=-1.), max=1.)

    out = torch.cat([x0.view(xs[:-1] + [1]), x1.view(xs[:-1] + [1]), x2.view(xs[:-1] + [1])], dim=3)
    # put back in Pytorch ordering
    out = out.permute(0, 3, 1, 2)
    return out
def append_ellapsed_time(func):
    """Decorator storing the wall-clock duration of `func` in its returned diagnostics.

    :param func: callable returning a nested mapping that contains an 'info' entry
    :return: wrapped callable; its output is the output of `func` with the elapsed
             time in seconds written to `output['info']['elapsed-time']`
    """
    # function-scope import so the module-level imports are left untouched
    from functools import wraps

    @wraps(func)  # keep the name and docstring of the decorated function
    def wrapper(*args, **kwargs):
        start_time = time()
        diagnostics = func(*args, **kwargs)
        diagnostics['info']['elapsed-time'] = time() - start_time
        return diagnostics

    return wrapper
def restore_session(logdir, device='auto'):
    """Load a model and its evaluation setup from a saved training session.

    Reads `hyperparameters.p` and `config.json` from `logdir`, rebuilds the model,
    loads the saved weights and reloads the datasets named in the configuration.

    :param logdir: directory of the training session
    :param device: target device, 'auto' selects the device returned by `available_device()`
    :return: dictionary with the keys 'model', 'device', 'run_id', 'hyperparameters', 'opt',
             'likelihood', 'evaluator', 'train_dataset', 'valid_dataset' and 'test_dataset'
    :raises KeyError: if the dataset stored in the configuration is not 'binmnist' or 'cifar10'
    """

    # the run id is the last component of the path, ignoring trailing slashes
    logdir = logdir.rstrip('/')
    run_id = logdir.split('/')[-1]

    # load the hyperparameters and arguments
    # NOTE(review): unpickling executes arbitrary code; only restore sessions from trusted sources
    with open(os.path.join(logdir, "hyperparameters.p"), "rb") as fp:
        hyperparameters = pickle.load(fp)
    with open(os.path.join(logdir, "config.json")) as fp:
        opt = json.load(fp)

    # instantiate the model
    model = DeepVae(**hyperparameters)
    device = available_device() if device == 'auto' else device
    model.to(device)

    # load pretrained weights
    load_model(model, logdir)

    # define likelihood and evaluator (the keys match the dataset names checked below)
    likelihood = {'cifar10': DiscretizedMixtureLogits(opt['nr_mix']), 'binmnist': Bernoulli}[opt['dataset']]
    evaluator = VariationalInference(likelihood, iw_samples=1)

    # load the dataset
    if opt['dataset'] == 'binmnist':
        train_dataset, valid_dataset, test_dataset = get_binmnist_datasets(opt['data_root'])
    elif opt['dataset'] == 'cifar10':
        from torchvision.transforms import Lambda

        # rescale the images from [0, 1] to [-1, 1]
        transform = Lambda(lambda x: x * 2 - 1)
        # `opt` is a plain dict loaded from json: index it instead of using attribute access
        train_dataset, valid_dataset, test_dataset = get_cifar10_datasets(opt['data_root'], transform=transform)
    else:
        raise NotImplementedError

    return {
        'model': model,
        'device': device,
        'run_id': run_id,
        'hyperparameters': hyperparameters,
        'opt': opt,
        'likelihood': likelihood,
        'evaluator': evaluator,
        'train_dataset': train_dataset,
        'valid_dataset': valid_dataset,
        'test_dataset': test_dataset,
    }
def log_sum_exp(tensor, dim=-1, sum_op=torch.sum, eps: float = 1e-12, keepdim=False):
    """
    Numerically stable log of a reduction of exponentials: log(sum_op(exp(tensor)) + eps).

    The maximum along `dim` is subtracted before exponentiating and added back afterwards.
    :param tensor: Tensor to compute LSE over
    :param dim: dimension to perform operation over
    :param sum_op: reductive operation to be applied, e.g. torch.sum or torch.mean
    :param eps: small constant added inside the log to avoid log(0)
    :param keepdim: whether the reduced dimension is retained in the output
    :return: LSE
    """
    # always keep the reduced dimension internally so that the maximum broadcasts
    # against `tensor` for any `dim`; it is removed at the end if not requested
    peak, _ = torch.max(tensor, dim=dim, keepdim=True)
    out = torch.log(sum_op(torch.exp(tensor - peak), dim=dim, keepdim=True) + eps) + peak
    return out if keepdim else out.squeeze(dim)
51 | """ 52 | 53 | def __init__(self, optimizer, gamma, lower_bound, last_epoch=-1): 54 | self.gamma = gamma 55 | self.lower_bound = lower_bound 56 | super(LowerBoundedExponentialLR, self).__init__(optimizer, last_epoch) 57 | 58 | def _get_lr(self, base_lr): 59 | lr = base_lr * self.gamma ** self.last_epoch 60 | if lr < self.lower_bound: 61 | lr = self.lower_bound 62 | return lr 63 | 64 | def get_lr(self): 65 | return [self._get_lr(base_lr) 66 | for base_lr in self.base_lrs] 67 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Bernoulli 3 | 4 | from biva import DenseNormal, ConvNormal 5 | from biva import VAE, LVAE, BIVA 6 | 7 | # build a 2 layers VAE for binary images 8 | 9 | # define the stochastic layers 10 | z = [ 11 | {'N': 8, 'kernel': 5, 'block': ConvNormal}, # z1 12 | {'N': 16, 'block': DenseNormal} # z2 13 | ] 14 | 15 | # define the intermediate layers 16 | # each stage defines the configuration of the blocks for q_(z_{l} | z_{l-1}) and p_(z_{l-1} | z_{l}) 17 | # each stage is defined by a sequence of 3 resnet blocks 18 | # each block is degined by a tuple [filters, kernel, stride] 19 | stages = [ 20 | [[64, 3, 1], [64, 3, 1], [64, 3, 2]], 21 | [[64, 3, 1], [64, 3, 1], [64, 3, 2]] 22 | ] 23 | 24 | # build the model 25 | model = VAE(tensor_shp=(-1, 1, 28, 28), stages=stages, latents=z, dropout=0.5) 26 | 27 | # forward pass and data-dependent initialization 28 | x = torch.empty((8, 1, 28, 28)).uniform_().bernoulli() 29 | data = model(x) # data = {'x_' : p(x|z), z \sim q(z|x), 'kl': [kl_z1, kl_z2]} 30 | 31 | # sample from prior 32 | data = model.sample_from_prior(N=16) # data = {'x_' : p(x|z), z \sim p(z)} 33 | samples = Bernoulli(logits=data['x_']).sample() 34 | -------------------------------------------------------------------------------- /load_deepvae.py: 
def build_and_save_grid(data, logdir, filename, N=100):
    """Arrange a batch of images in a grid, save it as a PNG file and display it.

    :param data: batch of images passed to `make_grid`
    :param logdir: directory in which the image is written
    :param filename: name of the output file, without the `.png` extension
    :param N: the grid has floor(sqrt(N)) images per row
    """
    images_per_row = math.floor(math.sqrt(N))
    grid = make_grid(data, nrow=images_per_row)

    # rescale in place: shift the minimum to zero, then divide by the new maximum
    grid.sub_(grid.min())
    grid.div_(grid.max())

    # convert to a channels-last numpy array and save the raw image
    channels_last = grid.data.permute(1, 2, 0)
    img = channels_last.cpu().numpy()
    target = os.path.join(logdir, f"{filename}.png")
    matplotlib.image.imsave(target, img)

    # display the grid, titled with the last component of the file name
    title = filename.split('/')[-1]
    plt.figure(figsize=(8, 8))
    plt.title(title)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
58 | # dislay prior samples x ~ p(x|z), z ~ p(z) 59 | x_ = model.sample_from_prior(100).get('x_') 60 | x_ = likelihood(logits=x_).sample() 61 | build_and_save_grid(x_, logdir, "prior") 62 | 63 | print(logging_sep("=")) 64 | print(f"Samples logged in {logdir}") 65 | 66 | # evaluate the likelihood on the batch of data 67 | _, diagnostics, _ = evaluator(model, x) 68 | print(logging_sep("=")) 69 | print(diagnostics) 70 | print(logging_sep("=")) 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2019.9.11 2 | cycler==0.10.0 3 | kiwisolver==1.1.0 4 | matplotlib==3.1.1 5 | numpy==1.17.3 6 | Pillow==8.3.2 7 | pyparsing==2.4.2 8 | python-dateutil==2.8.0 9 | six==1.12.0 10 | torch==1.3.0.post2 11 | torchvision==0.4.1.post2 12 | tqdm==4.36.1 13 | tensorboard==2.0.1 14 | booster-pytorch==0.0.2 15 | -------------------------------------------------------------------------------- /run_deepvae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import pickle 6 | 7 | import numpy as np 8 | import torch 9 | from biva.datasets import get_binmnist_datasets, get_cifar10_datasets 10 | from biva.evaluation import VariationalInference 11 | from biva.model import DeepVae, get_deep_vae_mnist, get_deep_vae_cifar, VaeStage, LvaeStage, BivaStage 12 | from biva.utils import LowerBoundedExponentialLR, training_step, test_step, summary2logger, save_model, load_model, \ 13 | sample_model, DiscretizedMixtureLogits 14 | from booster import Aggregator 15 | from booster.utils import EMA, logging_sep, available_device 16 | from torch.distributions import Bernoulli 17 | from torch.utils.data import DataLoader 18 | from torch.utils.tensorboard import SummaryWriter 19 | from tqdm import tqdm 20 | 21 | parser = argparse.ArgumentParser() 22 | 
parser.add_argument('--root', default='runs/', help='directory to store training logs') 23 | parser.add_argument('--data_root', default='data/', help='directory to store the dataset') 24 | parser.add_argument('--dataset', default='binmnist', help='binmnist') 25 | parser.add_argument('--model_type', default='biva', help='model type (vae | lvae | biva)') 26 | parser.add_argument('--device', default='auto', help='auto, cuda, cpu') 27 | parser.add_argument('--num_workers', default=1, type=int, help='number of workers') 28 | parser.add_argument('--bs', default=48, type=int, help='batch size') 29 | parser.add_argument('--epochs', default=500, type=int, help='number of epochs') 30 | parser.add_argument('--lr', default=2e-3, type=float, help='base learning rate') 31 | parser.add_argument('--seed', default=42, type=int, help='random seed') 32 | parser.add_argument('--freebits', default=2.0, type=float, help='freebits per latent variable') 33 | parser.add_argument('--nr_mix', default=10, type=int, help='number of mixtures') 34 | parser.add_argument('--ema', default=0.9995, type=float, help='ema') 35 | parser.add_argument('--q_dropout', default=0.5, type=float, help='inference model dropout') 36 | parser.add_argument('--p_dropout', default=0.5, type=float, help='generative model dropout') 37 | parser.add_argument('--iw_samples', default=1000, type=int, help='number of importance weighted samples for testing') 38 | parser.add_argument('--id', default='', type=str, help='run id suffix') 39 | parser.add_argument('--no_skip', action='store_true', help='do not use skip connections') 40 | parser.add_argument('--log_var_act', default='softplus', type=str, help='activation for the log variance') 41 | parser.add_argument('--beta', default=1.0, type=float, help='Beta parameter (Beta-VAE)') 42 | 43 | opt = parser.parse_args() 44 | 45 | # set random seed, set run-id, init log directory and save config 46 | torch.manual_seed(opt.seed) 47 | np.random.seed(opt.seed) 48 | run_id = 
f"{opt.dataset}-{opt.model_type}-seed{opt.seed}" 49 | if len(opt.id): 50 | run_id += f"-{opt.id}" 51 | if opt.beta != 1: 52 | run_id += f"-{opt.beta}" 53 | logdir = os.path.join(opt.root, run_id) 54 | if not os.path.exists(logdir): 55 | os.makedirs(logdir) 56 | with open(os.path.join(logdir, 'config.json'), 'w') as fp: 57 | fp.write(json.dumps(vars(opt))) 58 | 59 | # define tensorboard writers 60 | train_writer = SummaryWriter(os.path.join(logdir, 'train')) 61 | valid_writer = SummaryWriter(os.path.join(logdir, 'valid')) 62 | 63 | # load data 64 | if opt.dataset == 'binmnist': 65 | train_dataset, valid_dataset, test_dataset = get_binmnist_datasets(opt.data_root) 66 | elif opt.dataset == 'cifar10': 67 | from torchvision.transforms import Lambda 68 | 69 | transform = Lambda(lambda x: x * 2 - 1) 70 | train_dataset, valid_dataset, test_dataset = get_cifar10_datasets(opt.data_root, transform=transform) 71 | else: 72 | raise NotImplementedError 73 | 74 | train_loader = DataLoader(train_dataset, batch_size=opt.bs, shuffle=True, pin_memory=False, num_workers=opt.num_workers) 75 | valid_loader = DataLoader(valid_dataset, batch_size=2 * opt.bs, shuffle=True, pin_memory=False, 76 | num_workers=opt.num_workers) 77 | test_loader = DataLoader(test_dataset, batch_size=2 * opt.bs, shuffle=True, pin_memory=False, 78 | num_workers=opt.num_workers) 79 | tensor_shp = (-1, *train_dataset[0].shape) 80 | 81 | # define likelihood 82 | likelihood = {'cifar10': DiscretizedMixtureLogits(opt.nr_mix), 'binmnist': Bernoulli}[opt.dataset] 83 | 84 | # define model 85 | if 'cifar' in opt.dataset: 86 | stages, latents = get_deep_vae_cifar() 87 | features_out = 10 * opt.nr_mix 88 | else: 89 | stages, latents = get_deep_vae_mnist() 90 | features_out = tensor_shp[1] 91 | 92 | Stage = {'vae': VaeStage, 'lvae': LvaeStage, 'biva': BivaStage}[opt.model_type] 93 | log_var_act = {'none': None, 'softplus': torch.nn.Softplus, 'tanh': torch.nn.Tanh}[opt.log_var_act] 94 | hyperparameters = { 95 | 'Stage': 
Stage, 96 | 'tensor_shp': tensor_shp, 97 | 'stages': stages, 98 | 'latents': latents, 99 | 'nonlinearity': 'elu', 100 | 'q_dropout': opt.q_dropout, 101 | 'p_dropout': opt.p_dropout, 102 | 'type': opt.model_type, 103 | 'features_out': features_out, 104 | 'no_skip': opt.no_skip, 105 | 'log_var_act': log_var_act 106 | } 107 | # save hyper parameters for easy loading 108 | pickle.dump(hyperparameters, open(os.path.join(logdir, "hyperparameters.p"), "wb")) 109 | 110 | # instantiate the model and move to target device 111 | model = DeepVae(**hyperparameters) 112 | device = available_device() if opt.device == 'auto' else opt.device 113 | model.to(device) 114 | 115 | # define the evaluator 116 | evaluator = VariationalInference(likelihood, iw_samples=1) 117 | 118 | # define evaluation model with Exponential Moving Average 119 | ema = EMA(model, opt.ema) 120 | 121 | # data dependent init for weight normalization (automatically done during the first forward pass) 122 | with torch.no_grad(): 123 | model.train() 124 | x = next(iter(train_loader)).to(device) 125 | model(x) 126 | 127 | # print stages 128 | print(logging_sep("=") + "\nGenerative model:\n" + logging_sep("-")) 129 | for i, (convs, z) in reversed(list(enumerate(zip(stages, latents)))): 130 | print(f"Stage #{i + 1}") 131 | print("Stochastic layer:", z) 132 | print("Deterministic block:", convs) 133 | print(logging_sep("=")) 134 | 135 | # define freebits 136 | n_latents = len(latents) 137 | if opt.model_type == 'biva': 138 | n_latents = 2 * n_latents - 1 139 | freebits = [opt.freebits] * n_latents 140 | 141 | # optimizer 142 | optimizer = torch.optim.Adamax(model.parameters(), lr=opt.lr, betas=(0.9, 0.999,)) 143 | scheduler = LowerBoundedExponentialLR(optimizer, 0.999999, 0.0001) 144 | 145 | # logging utils 146 | kwargs = {'beta': opt.beta, 'freebits': freebits} 147 | best_elbo = (-1e20, 0, 0) 148 | global_step = 1 149 | train_agg = Aggregator() 150 | val_agg = Aggregator() 151 | 
logging.basicConfig(level=logging.INFO, 152 | format='%(asctime)s %(name)-4s %(levelname)-4s %(message)s', 153 | datefmt='%m-%d %H:%M', 154 | handlers=[logging.FileHandler(os.path.join(logdir, 'run.log')), 155 | logging.StreamHandler()]) 156 | train_logger = logging.getLogger('train') 157 | eval_logger = logging.getLogger('eval') 158 | M_parameters = (sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6) 159 | logging.getLogger(run_id).info(f'# Total Number of Parameters: {M_parameters:.3f}M') 160 | print(logging_sep() + f"\nLogging directory: {logdir}\n" + logging_sep()) 161 | 162 | # init sample 163 | sample_model(ema.model, likelihood, logdir, writer=valid_writer, global_step=global_step, N=100) 164 | 165 | # run 166 | for epoch in range(1, opt.epochs + 1): 167 | 168 | # training 169 | train_agg.initialize() 170 | for x in tqdm(train_loader, desc='train epoch'): 171 | x = x.to(device) 172 | diagnostics = training_step(x, model, evaluator, optimizer, scheduler, **kwargs) 173 | train_agg.update(diagnostics) 174 | ema.update() 175 | global_step += 1 176 | train_summary = train_agg.data.to('cpu') 177 | 178 | # evaluation 179 | val_agg.initialize() 180 | for x in tqdm(valid_loader, desc='valid epoch'): 181 | x = x.to(device) 182 | diagnostics = test_step(x, ema.model, evaluator, **kwargs) 183 | val_agg.update(diagnostics) 184 | eval_summary = val_agg.data.to('cpu') 185 | 186 | # keep best model 187 | best_elbo = save_model(ema.model, eval_summary, global_step, epoch, best_elbo, logdir) 188 | 189 | # logging 190 | summary2logger(train_logger, train_summary, global_step, epoch) 191 | summary2logger(eval_logger, eval_summary, global_step, epoch, best_elbo) 192 | 193 | # tensorboard logging 194 | train_summary.log(train_writer, global_step) 195 | eval_summary.log(valid_writer, global_step) 196 | 197 | # sample model 198 | sample_model(ema.model, likelihood, logdir, writer=valid_writer, global_step=global_step, N=100) 199 | 200 | # load best model 201 | 
import setuptools

# README.md contains non-ASCII characters (author names in the citation), so it is read
# explicitly as UTF-8 instead of relying on the platform default encoding
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

# package metadata; the long description shown on the package index is the README
setuptools.setup(
    name='biva-pytorch',
    version='0.1.4',
    author="Valentin Lievin",
    author_email="valentin.lievin@gmail.com",
    description="Official PyTorch BIVA implementation (BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling)",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/vlievin/biva-pytorch",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    install_requires=[
        'torch',
        'tqdm',
        'numpy',
        'matplotlib',
        'tensorboard',
        'booster-pytorch'
    ],
)