├── .gitignore ├── 511_supplement.pdf ├── LICENSE ├── README.md ├── bnaf.py ├── data ├── bsds300.py ├── gas.py ├── generate2d.py ├── hepmass.py ├── miniboone.py └── power.py ├── density_estimation.py ├── download_datasets.sh ├── optim ├── adam.py ├── adamax.py └── lr_scheduler.py └── toy2d.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # data and runs 107 | checkpoint/ 108 | data/BSDS300/ 109 | data/gas/ 110 | data/hepmass/ 111 | data/miniboone/ 112 | data/power/ 113 | tensorboard/ 114 | 115 | -------------------------------------------------------------------------------- /511_supplement.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicola-decao/BNAF/da43f564aa335a5a32922118316c70bb5c3d861c/511_supplement.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Nicola De Cao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BNAF 2 | PyTorch implementation of Block Neural Autoregressive Flow based on our paper: 3 | > De Cao Nicola, Titov Ivan and Aziz Wilker, [Block Neural Autoregressive Flow](http://arxiv.org/abs/1904.04676) (2019) 4 | 5 | ## Requirements 6 | * **``python>=3.6``** (it will probably work on older versions, but I have not tested them) 7 | * **``pytorch>=1.0.0``** 8 | 9 | Optional for visualization and plotting: ``numpy``, ``matplotlib`` and ``tensorboardX`` 10 | 11 | ## Structure 12 | * [bnaf.py](https://github.com/nicola-decao/BNAF/blob/master/bnaf.py): Implementation of Block Neural Autoregressive Flow. 13 | * [toy2d.py](https://github.com/nicola-decao/BNAF/blob/master/toy2d.py): Experiments on the 2D toy tasks (density estimation and energy matching). 14 | * [density_estimation.py](https://github.com/nicola-decao/BNAF/blob/master/density_estimation.py): Density estimation experiments on real datasets. 15 | * [optim](https://github.com/nicola-decao/BNAF/tree/master/optim): Custom extensions of `torch.optim.Adam` and `torch.optim.Adamax` with Polyak averaging, and of `torch.optim.lr_scheduler.ReduceLROnPlateau` with callbacks. 16 | * [data](https://github.com/nicola-decao/BNAF/tree/master/data): Data classes to handle the real datasets. 17 | 18 | ## Usage 19 | Example commands for running the experiments are given below. 20 | 21 | #### Download datasets 22 | Run the following command to download the datasets: 23 | ``` 24 | ./download_datasets.sh 25 | ``` 26 | 27 | #### Run 2D toy density estimation 28 | This example runs density estimation on the `8 Gaussians` dataset using 1 flow of BNAF with 2 layers and 100 hidden units (`50 * 2` since the data dimensionality is 2). 29 | ``` 30 | python toy2d.py --dataset 8gaussians \ # which dataset to use 31 | --experiment density2d \ # which experiment to run 32 | --flows 1 \ # BNAF flows to concatenate 33 | --layers 2 \ # layers for each flow of BNAF 34 | --hidden_dim 50 \ # hidden units per dimension for each hidden layer 35 | --save # save the model after training 36 | --savefig # save the density plot on disk 37 | ``` 38 | 39 | ![Imgur](https://i.imgur.com/DWVGsyn.jpg) 40 | 41 | #### Run 2D toy energy matching 42 | This example runs energy matching on the `t4` function using 1 flow of BNAF with 2 layers and 100 hidden units (`50 * 2` since the data dimensionality is 2).
43 | ``` 44 | python toy2d.py --dataset t4 \ # which dataset to use 45 | --experiment energy2d \ # which experiment to run 46 | --flows 1 \ # BNAF flows to concatenate 47 | --layers 2 \ # layers for each flow of BNAF 48 | --hidden_dim 50 \ # hidden units per dimension for each hidden layer 49 | --save # save the model after training 50 | --savefig # save the density plot on disk 51 | ``` 52 | 53 | ![Imgur](https://i.imgur.com/o1QR3XO.jpg) 54 | 55 | #### Run real dataset density estimation 56 | This example runs density estimation on the `MINIBOONE` dataset using 5 flows of BNAF with 0 layers. 57 | ``` 58 | python density_estimation.py --dataset miniboone \ # which dataset to use 59 | --flows 5 \ # BNAF flows to concatenate 60 | --layers 0 \ # layers for each flow of BNAF 61 | --hidden_dim 10 \ # hidden units per dimension for each hidden layer 62 | --save # save the model after training 63 | ``` 64 | 65 | ## Citation 66 | ``` 67 | De Cao Nicola, Titov Ivan, Aziz Wilker, 68 | Block Neural Autoregressive Flow, 69 | 35th Conference on Uncertainty in Artificial Intelligence (UAI19) (2019). 70 | ``` 71 | 72 | BibTeX format: 73 | ``` 74 | @article{bnaf19, 75 | title={Block Neural Autoregressive Flow}, 76 | author={De Cao, Nicola and 77 | Titov, Ivan and 78 | Aziz, Wilker}, 79 | journal={35th Conference on Uncertainty in Artificial Intelligence (UAI19)}, 80 | year={2019} 81 | } 82 | ``` 83 | 84 | ## Feedback 85 | For questions and comments, feel free to contact [Nicola De Cao](mailto:nicola.decao@gmail.com). 86 | 87 | ## License 88 | MIT 89 | -------------------------------------------------------------------------------- /bnaf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class Sequential(torch.nn.Sequential): 6 | """ 7 | Class that extends ``torch.nn.Sequential`` for computing the output of 8 | the function alongside with the log-det-Jacobian of such transformation. 9 | """ 10 | 11 | def forward(self, inputs: torch.Tensor): 12 | """ 13 | Parameters 14 | ---------- 15 | inputs : ``torch.Tensor``, required. 16 | The input tensor. 17 | Returns 18 | ------- 19 | The output tensor and the log-det-Jacobian of this transformation. 20 | """ 21 | 22 | log_det_jacobian = 0.0 23 | for i, module in enumerate(self._modules.values()): 24 | inputs, log_det_jacobian_ = module(inputs) 25 | log_det_jacobian = log_det_jacobian + log_det_jacobian_ 26 | return inputs, log_det_jacobian 27 | 28 | 29 | class BNAF(torch.nn.Sequential): 30 | """ 31 | Class that extends ``torch.nn.Sequential`` for constructing a Block Neural 32 | Normalizing Flow. 33 | """ 34 | 35 | def __init__(self, *args, res: str = None): 36 | """ 37 | Parameters 38 | ---------- 39 | *args : ``Iterable[torch.nn.Module]``, required. 40 | The modules to use. 41 | res : ``str``, optional (default = None). 42 | Which kind of residual connection to use. ``res = None`` is no residual 43 | connection, ``res = 'normal'`` is ``x + f(x)`` and ``res = 'gated'`` is 44 | ``a * x + (1 - a) * f(x)`` where ``a`` is a learnable parameter. 45 | """ 46 | 47 | super(BNAF, self).__init__(*args) 48 | 49 | self.res = res 50 | 51 | if res == "gated": 52 | self.gate = torch.nn.Parameter(torch.nn.init.normal_(torch.Tensor(1))) 53 | 54 | def forward(self, inputs: torch.Tensor): 55 | """ 56 | Parameters 57 | ---------- 58 | inputs : ``torch.Tensor``, required. 59 | The input tensor. 60 | Returns 61 | ------- 62 | The output tensor and the log-det-Jacobian of this transformation. 
63 | """ 64 | 65 | outputs = inputs 66 | grad = None 67 | 68 | for module in self._modules.values(): 69 | outputs, grad = module(outputs, grad) 70 | 71 | grad = grad if len(grad.shape) == 4 else grad.view(grad.shape + (1, 1))  # grad.shape is a torch.Size (a tuple), so extend it with a tuple 72 | 73 | assert inputs.shape[-1] == outputs.shape[-1] 74 | 75 | if self.res == "normal": 76 | return inputs + outputs, torch.nn.functional.softplus(grad.squeeze()).sum( 77 | -1 78 | ) 79 | elif self.res == "gated": 80 | return self.gate.sigmoid() * outputs + (1 - self.gate.sigmoid()) * inputs, ( 81 | torch.nn.functional.softplus(grad.squeeze() + self.gate) 82 | - torch.nn.functional.softplus(self.gate) 83 | ).sum(-1) 84 | else: 85 | return outputs, grad.squeeze().sum(-1) 86 | 87 | def _get_name(self): 88 | return "BNAF(res={})".format(self.res) 89 | 90 | 91 | class Permutation(torch.nn.Module): 92 | """ 93 | Module that outputs a permutation of its input. 94 | """ 95 | 96 | def __init__(self, in_features: int, p: list = None): 97 | """ 98 | Parameters 99 | ---------- 100 | in_features : ``int``, required. 101 | The number of input features. 102 | p : ``list`` or ``str``, optional (default = None) 103 | The list of indices that defines the permutation. When ``p`` is not a 104 | list, if ``p = 'flip'`` the tensor is reversed, if ``p = None`` a random 105 | permutation is applied. 106 | """ 107 | 108 | super(Permutation, self).__init__() 109 | 110 | self.in_features = in_features 111 | 112 | if p is None: 113 | self.p = torch.randperm(in_features).tolist()  # random permutation via torch; numpy is not imported in this module 114 | elif p == "flip": 115 | self.p = list(reversed(range(in_features))) 116 | else: 117 | self.p = p 118 | 119 | def forward(self, inputs: torch.Tensor): 120 | """ 121 | Parameters 122 | ---------- 123 | inputs : ``torch.Tensor``, required. 124 | The input tensor. 125 | Returns 126 | ------- 127 | The permuted tensor and the log-det-Jacobian of this permutation. 128 | """ 129 | 130 | return inputs[:, self.p], 0 131 | 132 | def __repr__(self): 133 | return "Permutation(in_features={}, p={})".format(self.in_features, self.p) 134 | 135 | 136 | class MaskedWeight(torch.nn.Module): 137 | """ 138 | Module that implements a linear layer with block matrices with positive diagonal blocks. 139 | Moreover, it uses Weight Normalization (https://arxiv.org/abs/1602.07868) for stability. 140 | """ 141 | 142 | def __init__( 143 | self, in_features: int, out_features: int, dim: int, bias: bool = True 144 | ): 145 | """ 146 | Parameters 147 | ---------- 148 | in_features : ``int``, required. 149 | The total number of input features (a multiple of ``dim``). 150 | out_features : ``int``, required. 151 | The total number of output features (a multiple of ``dim``). 152 | dim : ``int``, required. 153 | The number of dimensions of the input of the flow. 154 | bias : ``bool``, optional (default = True). 155 | Whether to add a parametrizable bias.
156 | """ 157 | 158 | super(MaskedWeight, self).__init__() 159 | self.in_features, self.out_features, self.dim = in_features, out_features, dim 160 | 161 | weight = torch.zeros(out_features, in_features) 162 | for i in range(dim): 163 | weight[ 164 | i * out_features // dim : (i + 1) * out_features // dim, 165 | 0 : (i + 1) * in_features // dim, 166 | ] = torch.nn.init.xavier_uniform_( 167 | torch.Tensor(out_features // dim, (i + 1) * in_features // dim) 168 | ) 169 | 170 | self._weight = torch.nn.Parameter(weight) 171 | self._diag_weight = torch.nn.Parameter( 172 | torch.nn.init.uniform_(torch.Tensor(out_features, 1)).log() 173 | ) 174 | 175 | self.bias = ( 176 | torch.nn.Parameter( 177 | torch.nn.init.uniform_( 178 | torch.Tensor(out_features), 179 | -1 / math.sqrt(out_features), 180 | 1 / math.sqrt(out_features), 181 | ) 182 | ) 183 | if bias 184 | else 0 185 | ) 186 | 187 | mask_d = torch.zeros_like(weight) 188 | for i in range(dim): 189 | mask_d[ 190 | i * (out_features // dim) : (i + 1) * (out_features // dim), 191 | i * (in_features // dim) : (i + 1) * (in_features // dim), 192 | ] = 1 193 | 194 | self.register_buffer("mask_d", mask_d) 195 | 196 | mask_o = torch.ones_like(weight) 197 | for i in range(dim): 198 | mask_o[ 199 | i * (out_features // dim) : (i + 1) * (out_features // dim), 200 | i * (in_features // dim) :, 201 | ] = 0 202 | 203 | self.register_buffer("mask_o", mask_o) 204 | 205 | def get_weights(self): 206 | """ 207 | Computes the weight matrix using masks and weight normalization. 208 | It also compute the log diagonal blocks of it. 209 | """ 210 | 211 | w = torch.exp(self._weight) * self.mask_d + self._weight * self.mask_o 212 | 213 | w_squared_norm = (w ** 2).sum(-1, keepdim=True) 214 | 215 | w = self._diag_weight.exp() * w / w_squared_norm.sqrt() 216 | 217 | wpl = self._diag_weight + self._weight - 0.5 * torch.log(w_squared_norm) 218 | 219 | return w.t(), wpl.t()[self.mask_d.bool().t()].view( 220 | self.dim, self.in_features // self.dim, self.out_features // self.dim 221 | ) 222 | 223 | def forward(self, inputs, grad: torch.Tensor = None): 224 | """ 225 | Parameters 226 | ---------- 227 | inputs : ``torch.Tensor``, required. 228 | The input tensor. 229 | grad : ``torch.Tensor``, optional (default = None). 230 | The log diagonal block of the partial Jacobian of previous transformations. 231 | Returns 232 | ------- 233 | The output tensor and the log diagonal blocks of the partial log-Jacobian of previous 234 | transformations combined with this transformation. 235 | """ 236 | 237 | w, wpl = self.get_weights() 238 | 239 | g = wpl.transpose(-2, -1).unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1) 240 | 241 | return ( 242 | inputs.matmul(w) + self.bias, 243 | torch.logsumexp(g.unsqueeze(-2) + grad.transpose(-2, -1).unsqueeze(-3), -1) 244 | if grad is not None 245 | else g, 246 | ) 247 | 248 | def __repr__(self): 249 | return "MaskedWeight(in_features={}, out_features={}, dim={}, bias={})".format( 250 | self.in_features, 251 | self.out_features, 252 | self.dim, 253 | not isinstance(self.bias, int), 254 | ) 255 | 256 | 257 | class Tanh(torch.nn.Tanh): 258 | """ 259 | Class that extends ``torch.nn.Tanh`` additionally computing the log diagonal 260 | blocks of the Jacobian. 261 | """ 262 | 263 | def forward(self, inputs, grad: torch.Tensor = None): 264 | """ 265 | Parameters 266 | ---------- 267 | inputs : ``torch.Tensor``, required. 268 | The input tensor. 269 | grad : ``torch.Tensor``, optional (default = None). 
270 | The log diagonal blocks of the partial Jacobian of previous transformations. 271 | Returns 272 | ------- 273 | The output tensor and the log diagonal blocks of the partial log-Jacobian of previous 274 | transformations combined with this transformation. 275 | """ 276 | 277 | g = -2 * (inputs - math.log(2) + torch.nn.functional.softplus(-2 * inputs)) 278 | return ( 279 | torch.tanh(inputs), 280 | (g.view(grad.shape) + grad) if grad is not None else g, 281 | ) 282 | -------------------------------------------------------------------------------- /data/bsds300.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | 4 | 5 | class BSDS300: 6 | """ 7 | A dataset of patches from BSDS300. 8 | """ 9 | 10 | class Data: 11 | """ 12 | Constructs the dataset. 13 | """ 14 | 15 | def __init__(self, data): 16 | 17 | self.x = data[:] 18 | self.N = self.x.shape[0] 19 | 20 | def __init__(self, file): 21 | 22 | # load dataset 23 | f = h5py.File(file, "r") 24 | 25 | self.trn = self.Data(f["train"]) 26 | self.val = self.Data(f["validation"]) 27 | self.tst = self.Data(f["test"]) 28 | 29 | self.n_dims = self.trn.x.shape[1] 30 | self.image_size = [int(np.sqrt(self.n_dims + 1))] * 2 31 | 32 | f.close() 33 | -------------------------------------------------------------------------------- /data/gas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | class GAS: 6 | class Data: 7 | def __init__(self, data): 8 | 9 | self.x = data.astype(np.float32) 10 | self.N = self.x.shape[0] 11 | 12 | def __init__(self, file): 13 | 14 | trn, val, tst = load_data_and_clean_and_split(file) 15 | 16 | self.trn = self.Data(trn) 17 | self.val = self.Data(val) 18 | self.tst = self.Data(tst) 19 | 20 | self.n_dims = self.trn.x.shape[1] 21 | 22 | 23 | def load_data(file): 24 | 25 | data = pd.read_pickle(file) 26 | # data = pd.read_pickle(file).sample(frac=0.25) 27 | # data.to_pickle(file) 28 | data.drop("Meth", axis=1, inplace=True) 29 | data.drop("Eth", axis=1, inplace=True) 30 | data.drop("Time", axis=1, inplace=True) 31 | return data 32 | 33 | 34 | def get_correlation_numbers(data): 35 | C = data.corr() 36 | A = C > 0.98 37 | B = A.as_matrix().sum(axis=1) 38 | return B 39 | 40 | 41 | def load_data_and_clean(file): 42 | 43 | data = load_data(file) 44 | B = get_correlation_numbers(data) 45 | 46 | while np.any(B > 1): 47 | col_to_remove = np.where(B > 1)[0][0] 48 | col_name = data.columns[col_to_remove] 49 | data.drop(col_name, axis=1, inplace=True) 50 | B = get_correlation_numbers(data) 51 | # print(data.corr()) 52 | data = (data - data.mean()) / data.std() 53 | 54 | return data 55 | 56 | 57 | def load_data_and_clean_and_split(file): 58 | 59 | data = load_data_and_clean(file).as_matrix() 60 | N_test = int(0.1 * data.shape[0]) 61 | data_test = data[-N_test:] 62 | data_train = data[0:-N_test] 63 | N_validate = int(0.1 * data_train.shape[0]) 64 | data_validate = data_train[-N_validate:] 65 | data_train = data_train[0:-N_validate] 66 | 67 | return data_train, data_validate, data_test 68 | -------------------------------------------------------------------------------- /data/generate2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sklearn 3 | import numpy as np 4 | 5 | 6 | def sample2d(data, batch_size=200): 7 | 8 | rng = np.random.RandomState() 9 | 10 | if data == "8gaussians": 11 | scale = 4.0 12 | centers = [ 13 
| (1, 0), 14 | (-1, 0), 15 | (0, 1), 16 | (0, -1), 17 | (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 18 | (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 19 | (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), 20 | (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), 21 | ] 22 | centers = [(scale * x, scale * y) for x, y in centers] 23 | 24 | dataset = [] 25 | for i in range(batch_size): 26 | point = rng.randn(2) * 0.5 27 | idx = rng.randint(8) 28 | center = centers[idx] 29 | point[0] += center[0] 30 | point[1] += center[1] 31 | dataset.append(point) 32 | dataset = np.array(dataset, dtype="float32") 33 | dataset /= 1.414 34 | return dataset 35 | 36 | elif data == "2spirals": 37 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 38 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 39 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 40 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 41 | x += np.random.randn(*x.shape) * 0.1 42 | return x 43 | 44 | elif data == "checkerboard": 45 | x1 = np.random.rand(batch_size) * 4 - 2 46 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 47 | x2 = x2_ + (np.floor(x1) % 2) 48 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 49 | 50 | else: 51 | raise RuntimeError 52 | 53 | 54 | def energy2d(data, z): 55 | 56 | if data == "t1": 57 | return U1(z) 58 | elif data == "t2": 59 | return U2(z) 60 | elif data == "t3": 61 | return U3(z) 62 | elif data == "t4": 63 | return U4(z) 64 | else: 65 | raise RuntimeError 66 | 67 | 68 | def w1(z): 69 | return torch.sin(2 * np.pi * z[:, 0] / 4) 70 | 71 | 72 | def w2(z): 73 | return 3 * torch.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2) 74 | 75 | 76 | def w3(z): 77 | return 3 * torch.sigmoid((z[:, 0] - 1) / 0.3) 78 | 79 | 80 | def U1(z): 81 | z_norm = torch.norm(z, 2, 1) 82 | add1 = 0.5 * ((z_norm - 2) / 0.4) ** 2 83 | add2 = -torch.log( 84 | torch.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) 85 | + torch.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2) 86 | + 1e-9 87 | ) 88 | 89 | return add1 + add2 90 | 91 | 92 | def U2(z): 93 | return 0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2 94 | 95 | 96 | def U3(z): 97 | in1 = torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.35) ** 2) 98 | in2 = torch.exp(-0.5 * ((z[:, 1] - w1(z) + w2(z)) / 0.35) ** 2) 99 | return -torch.log(in1 + in2 + 1e-9) 100 | 101 | 102 | def U4(z): 103 | in1 = torch.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2) 104 | in2 = torch.exp(-0.5 * ((z[:, 1] - w1(z) + w3(z)) / 0.35) ** 2) 105 | return -torch.log(in1 + in2 + 1e-9) 106 | -------------------------------------------------------------------------------- /data/hepmass.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from collections import Counter 4 | from os.path import join 5 | 6 | 7 | class HEPMASS: 8 | """ 9 | The HEPMASS data set. 
10 | http://archive.ics.uci.edu/ml/datasets/HEPMASS 11 | """ 12 | 13 | class Data: 14 | def __init__(self, data): 15 | 16 | self.x = data.astype(np.float32) 17 | self.N = self.x.shape[0] 18 | 19 | def __init__(self, path): 20 | 21 | trn, val, tst = load_data_no_discrete_normalised_as_array(path) 22 | 23 | self.trn = self.Data(trn) 24 | self.val = self.Data(val) 25 | self.tst = self.Data(tst) 26 | 27 | self.n_dims = self.trn.x.shape[1] 28 | 29 | 30 | def load_data(path): 31 | 32 | data_train = pd.read_csv( 33 | filepath_or_buffer=join(path, "1000_train.csv"), index_col=False 34 | ) 35 | data_test = pd.read_csv( 36 | filepath_or_buffer=join(path, "1000_test.csv"), index_col=False 37 | ) 38 | 39 | return data_train, data_test 40 | 41 | 42 | def load_data_no_discrete(path): 43 | """ 44 | Loads the positive class examples from the first 10 percent of the dataset. 45 | """ 46 | data_train, data_test = load_data(path) 47 | 48 | # Gets rid of any background noise examples i.e. class label 0. 49 | data_train = data_train[data_train[data_train.columns[0]] == 1] 50 | data_train = data_train.drop(data_train.columns[0], axis=1) 51 | data_test = data_test[data_test[data_test.columns[0]] == 1] 52 | data_test = data_test.drop(data_test.columns[0], axis=1) 53 | # Because the data set is messed up! 54 | data_test = data_test.drop(data_test.columns[-1], axis=1) 55 | 56 | return data_train, data_test 57 | 58 | 59 | def load_data_no_discrete_normalised(path): 60 | 61 | data_train, data_test = load_data_no_discrete(path) 62 | mu = data_train.mean() 63 | s = data_train.std() 64 | data_train = (data_train - mu) / s 65 | data_test = (data_test - mu) / s 66 | 67 | return data_train, data_test 68 | 69 | 70 | def load_data_no_discrete_normalised_as_array(path): 71 | 72 | data_train, data_test = load_data_no_discrete_normalised(path) 73 | data_train, data_test = data_train.as_matrix(), data_test.as_matrix() 74 | 75 | i = 0 76 | # Remove any features that have too many re-occurring real values. 77 | features_to_remove = [] 78 | for feature in data_train.T: 79 | c = Counter(feature) 80 | max_count = np.array([v for k, v in sorted(c.items())])[0] 81 | if max_count > 5: 82 | features_to_remove.append(i) 83 | i += 1 84 | data_train = data_train[ 85 | :, 86 | np.array( 87 | [i for i in range(data_train.shape[1]) if i not in features_to_remove] 88 | ), 89 | ] 90 | data_test = data_test[ 91 | :, 92 | np.array([i for i in range(data_test.shape[1]) if i not in features_to_remove]), 93 | ] 94 | 95 | N = data_train.shape[0] 96 | N_validate = int(N * 0.1) 97 | data_validate = data_train[-N_validate:] 98 | data_train = data_train[0:-N_validate] 99 | 100 | return data_train, data_validate, data_test 101 | -------------------------------------------------------------------------------- /data/miniboone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MINIBOONE: 5 | class Data: 6 | def __init__(self, data): 7 | 8 | self.x = data.astype(np.float32) 9 | self.N = self.x.shape[0] 10 | 11 | def __init__(self, file): 12 | 13 | trn, val, tst = load_data_normalised(file) 14 | 15 | self.trn = self.Data(trn) 16 | self.val = self.Data(val) 17 | self.tst = self.Data(tst) 18 | 19 | self.n_dims = self.trn.x.shape[1] 20 | 21 | 22 | def load_data(root_path): 23 | # NOTE: To remember how the pre-processing was done. 
24 | # data = pd.read_csv(root_path, names=[str(x) for x in range(50)], delim_whitespace=True) 25 | # print data.head() 26 | # data = data.as_matrix() 27 | # # Remove some random outliers 28 | # indices = (data[:, 0] < -100) 29 | # data = data[~indices] 30 | # 31 | # i = 0 32 | # # Remove any features that have too many re-occuring real values. 33 | # features_to_remove = [] 34 | # for feature in data.T: 35 | # c = Counter(feature) 36 | # max_count = np.array([v for k, v in sorted(c.iteritems())])[0] 37 | # if max_count > 5: 38 | # features_to_remove.append(i) 39 | # i += 1 40 | # data = data[:, np.array([i for i in range(data.shape[1]) if i not in features_to_remove])] 41 | # np.save("~/data/miniboone/data.npy", data) 42 | 43 | data = np.load(root_path) 44 | N_test = int(0.1 * data.shape[0]) 45 | data_test = data[-N_test:] 46 | data = data[0:-N_test] 47 | N_validate = int(0.1 * data.shape[0]) 48 | data_validate = data[-N_validate:] 49 | data_train = data[0:-N_validate] 50 | 51 | return data_train, data_validate, data_test 52 | 53 | 54 | def load_data_normalised(root_path): 55 | 56 | data_train, data_validate, data_test = load_data(root_path) 57 | data = np.vstack((data_train, data_validate)) 58 | mu = data.mean(axis=0) 59 | s = data.std(axis=0) 60 | data_train = (data_train - mu) / s 61 | data_validate = (data_validate - mu) / s 62 | data_test = (data_test - mu) / s 63 | 64 | return data_train, data_validate, data_test 65 | -------------------------------------------------------------------------------- /data/power.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class POWER: 5 | class Data: 6 | def __init__(self, data): 7 | 8 | self.x = data.astype(np.float32) 9 | self.N = self.x.shape[0] 10 | 11 | def __init__(self, file): 12 | 13 | trn, val, tst = load_data_normalised(file) 14 | 15 | self.trn = self.Data(trn) 16 | self.val = self.Data(val) 17 | self.tst = self.Data(tst) 18 | 19 | self.n_dims = self.trn.x.shape[1] 20 | 21 | 22 | def load_data(file): 23 | return np.load(file) 24 | 25 | 26 | def load_data_split_with_noise(file): 27 | 28 | rng = np.random.RandomState(42) 29 | 30 | data = load_data(file) 31 | rng.shuffle(data) 32 | N = data.shape[0] 33 | 34 | data = np.delete(data, 3, axis=1) 35 | data = np.delete(data, 1, axis=1) 36 | ############################ 37 | # Add noise 38 | ############################ 39 | # global_intensity_noise = 0.1*rng.rand(N, 1) 40 | voltage_noise = 0.01 * rng.rand(N, 1) 41 | # grp_noise = 0.001*rng.rand(N, 1) 42 | gap_noise = 0.001 * rng.rand(N, 1) 43 | sm_noise = rng.rand(N, 3) 44 | time_noise = np.zeros((N, 1)) 45 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, global_intensity_noise, sm_noise, time_noise)) 46 | # noise = np.hstack((gap_noise, grp_noise, voltage_noise, sm_noise, time_noise)) 47 | noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise)) 48 | data = data + noise 49 | 50 | N_test = int(0.1 * data.shape[0]) 51 | data_test = data[-N_test:] 52 | data = data[0:-N_test] 53 | N_validate = int(0.1 * data.shape[0]) 54 | data_validate = data[-N_validate:] 55 | data_train = data[0:-N_validate] 56 | 57 | return data_train, data_validate, data_test 58 | 59 | 60 | def load_data_normalised(file): 61 | 62 | data_train, data_validate, data_test = load_data_split_with_noise(file) 63 | data = np.vstack((data_train, data_validate)) 64 | mu = data.mean(axis=0) 65 | s = data.std(axis=0) 66 | data_train = (data_train - mu) / s 67 | data_validate = (data_validate - mu) 
/ s 68 | data_test = (data_test - mu) / s 69 | 70 | return data_train, data_validate, data_test 71 | -------------------------------------------------------------------------------- /density_estimation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pprint 5 | import datetime 6 | import torch 7 | from torch.utils import data 8 | from bnaf import * 9 | from tqdm import tqdm 10 | from optim.adam import Adam 11 | from optim.lr_scheduler import ReduceLROnPlateau 12 | 13 | from data.gas import GAS 14 | from data.bsds300 import BSDS300 15 | from data.hepmass import HEPMASS 16 | from data.miniboone import MINIBOONE 17 | from data.power import POWER 18 | 19 | NAF_PARAMS = { 20 | "power": (414213, 828258), 21 | "gas": (401741, 803226), 22 | "hepmass": (9272743, 18544268), 23 | "miniboone": (7487321, 14970256), 24 | "bsds300": (36759591, 73510236), 25 | } 26 | 27 | 28 | def load_dataset(args): 29 | if args.dataset == "gas": 30 | dataset = GAS("data/gas/ethylene_CO.pickle") 31 | elif args.dataset == "bsds300": 32 | dataset = BSDS300("data/BSDS300/BSDS300.hdf5") 33 | elif args.dataset == "hepmass": 34 | dataset = HEPMASS("data/hepmass") 35 | elif args.dataset == "miniboone": 36 | dataset = MINIBOONE("data/miniboone/data.npy") 37 | elif args.dataset == "power": 38 | dataset = POWER("data/power/data.npy") 39 | else: 40 | raise RuntimeError() 41 | 42 | dataset_train = torch.utils.data.TensorDataset( 43 | torch.from_numpy(dataset.trn.x).float().to(args.device) 44 | ) 45 | data_loader_train = torch.utils.data.DataLoader( 46 | dataset_train, batch_size=args.batch_dim, shuffle=True 47 | ) 48 | 49 | dataset_valid = torch.utils.data.TensorDataset( 50 | torch.from_numpy(dataset.val.x).float().to(args.device) 51 | ) 52 | data_loader_valid = torch.utils.data.DataLoader( 53 | dataset_valid, batch_size=args.batch_dim, shuffle=False 54 | ) 55 | 56 | dataset_test = torch.utils.data.TensorDataset( 57 | torch.from_numpy(dataset.tst.x).float().to(args.device) 58 | ) 59 | data_loader_test = torch.utils.data.DataLoader( 60 | dataset_test, batch_size=args.batch_dim, shuffle=False 61 | ) 62 | 63 | args.n_dims = dataset.n_dims 64 | 65 | return data_loader_train, data_loader_valid, data_loader_test 66 | 67 | 68 | def create_model(args, verbose=False): 69 | 70 | flows = [] 71 | for f in range(args.flows): 72 | layers = [] 73 | for _ in range(args.layers - 1): 74 | layers.append( 75 | MaskedWeight( 76 | args.n_dims * args.hidden_dim, 77 | args.n_dims * args.hidden_dim, 78 | dim=args.n_dims, 79 | ) 80 | ) 81 | layers.append(Tanh()) 82 | 83 | flows.append( 84 | BNAF( 85 | *( 86 | [ 87 | MaskedWeight( 88 | args.n_dims, args.n_dims * args.hidden_dim, dim=args.n_dims 89 | ), 90 | Tanh(), 91 | ] 92 | + layers 93 | + [ 94 | MaskedWeight( 95 | args.n_dims * args.hidden_dim, args.n_dims, dim=args.n_dims 96 | ) 97 | ] 98 | ), 99 | res=args.residual if f < args.flows - 1 else None 100 | ) 101 | ) 102 | 103 | if f < args.flows - 1: 104 | flows.append(Permutation(args.n_dims, "flip")) 105 | 106 | model = Sequential(*flows).to(args.device) 107 | params = sum( 108 | (p != 0).sum() if len(p.shape) > 1 else torch.tensor(p.shape).item() 109 | for p in model.parameters() 110 | ).item() 111 | 112 | if verbose: 113 | print("{}".format(model)) 114 | print( 115 | "Parameters={}, NAF/BNAF={:.2f}/{:.2f}, n_dims={}".format( 116 | params, 117 | NAF_PARAMS[args.dataset][0] / params, 118 | NAF_PARAMS[args.dataset][1] / params, 119 | args.n_dims, 120 | ) 121 | 
) 122 | 123 | if args.save and not args.load: 124 | with open(os.path.join(args.load or args.path, "results.txt"), "a") as f: 125 | print( 126 | "Parameters={}, NAF/BNAF={:.2f}/{:.2f}, n_dims={}".format( 127 | params, 128 | NAF_PARAMS[args.dataset][0] / params, 129 | NAF_PARAMS[args.dataset][1] / params, 130 | args.n_dims, 131 | ), 132 | file=f, 133 | ) 134 | 135 | return model 136 | 137 | 138 | def save_model(model, optimizer, epoch, args): 139 | def f(): 140 | if args.save: 141 | print("Saving model..") 142 | torch.save( 143 | { 144 | "model": model.state_dict(), 145 | "optimizer": optimizer.state_dict(), 146 | "epoch": epoch, 147 | }, 148 | os.path.join(args.load or args.path, "checkpoint.pt"), 149 | ) 150 | 151 | return f 152 | 153 | 154 | def load_model(model, optimizer, args, load_start_epoch=False): 155 | def f(): 156 | print("Loading model..") 157 | checkpoint = torch.load(os.path.join(args.load or args.path, "checkpoint.pt")) 158 | model.load_state_dict(checkpoint["model"]) 159 | optimizer.load_state_dict(checkpoint["optimizer"]) 160 | 161 | if load_start_epoch: 162 | args.start_epoch = checkpoint["epoch"] 163 | 164 | return f 165 | 166 | 167 | def compute_log_p_x(model, x_mb): 168 | y_mb, log_diag_j_mb = model(x_mb) 169 | log_p_y_mb = ( 170 | torch.distributions.Normal(torch.zeros_like(y_mb), torch.ones_like(y_mb)) 171 | .log_prob(y_mb) 172 | .sum(-1) 173 | ) 174 | return log_p_y_mb + log_diag_j_mb 175 | 176 | 177 | def train( 178 | model, 179 | optimizer, 180 | scheduler, 181 | data_loader_train, 182 | data_loader_valid, 183 | data_loader_test, 184 | args, 185 | ): 186 | 187 | if args.tensorboard: 188 | from tensorboardX import SummaryWriter 189 | 190 | writer = SummaryWriter(os.path.join(args.tensorboard, args.load or args.path)) 191 | 192 | epoch = args.start_epoch 193 | for epoch in range(args.start_epoch, args.start_epoch + args.epochs): 194 | 195 | t = tqdm(data_loader_train, smoothing=0, ncols=80) 196 | train_loss = [] 197 | 198 | for (x_mb,) in t: 199 | loss = -compute_log_p_x(model, x_mb).mean() 200 | 201 | loss.backward() 202 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm) 203 | 204 | optimizer.step() 205 | optimizer.zero_grad() 206 | 207 | t.set_postfix(loss="{:.2f}".format(loss.item()), refresh=False) 208 | train_loss.append(loss) 209 | 210 | train_loss = torch.stack(train_loss).mean() 211 | optimizer.swap() 212 | validation_loss = -torch.stack( 213 | [ 214 | compute_log_p_x(model, x_mb).mean().detach() 215 | for x_mb, in data_loader_valid 216 | ], 217 | -1, 218 | ).mean() 219 | optimizer.swap() 220 | 221 | print( 222 | "Epoch {:3}/{:3} -- train_loss: {:4.3f} -- validation_loss: {:4.3f}".format( 223 | epoch + 1, 224 | args.start_epoch + args.epochs, 225 | train_loss.item(), 226 | validation_loss.item(), 227 | ) 228 | ) 229 | 230 | stop = scheduler.step( 231 | validation_loss, 232 | callback_best=save_model(model, optimizer, epoch + 1, args), 233 | callback_reduce=load_model(model, optimizer, args), 234 | ) 235 | 236 | if args.tensorboard: 237 | writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch + 1) 238 | writer.add_scalar("loss/validation", validation_loss.item(), epoch + 1) 239 | writer.add_scalar("loss/train", train_loss.item(), epoch + 1) 240 | 241 | if stop: 242 | break 243 | 244 | load_model(model, optimizer, args)() 245 | optimizer.swap() 246 | validation_loss = -torch.stack( 247 | [compute_log_p_x(model, x_mb).mean().detach() for x_mb, in data_loader_valid], 248 | -1, 249 | ).mean() 250 | test_loss = -torch.stack( 251 
| [compute_log_p_x(model, x_mb).mean().detach() for x_mb, in data_loader_test], -1 252 | ).mean() 253 | 254 | print("###### Stop training after {} epochs!".format(epoch + 1)) 255 | print("Validation loss: {:4.3f}".format(validation_loss.item())) 256 | print("Test loss: {:4.3f}".format(test_loss.item())) 257 | 258 | if args.save: 259 | with open(os.path.join(args.load or args.path, "results.txt"), "a") as f: 260 | print("###### Stop training after {} epochs!".format(epoch + 1), file=f) 261 | print("Validation loss: {:4.3f}".format(validation_loss.item()), file=f) 262 | print("Test loss: {:4.3f}".format(test_loss.item()), file=f) 263 | 264 | 265 | def main(): 266 | parser = argparse.ArgumentParser() 267 | parser.add_argument("--device", type=str, default="cuda:0") 268 | parser.add_argument( 269 | "--dataset", 270 | type=str, 271 | default="miniboone", 272 | choices=["gas", "bsds300", "hepmass", "miniboone", "power"], 273 | ) 274 | 275 | parser.add_argument("--learning_rate", type=float, default=1e-2) 276 | parser.add_argument("--batch_dim", type=int, default=200) 277 | parser.add_argument("--clip_norm", type=float, default=0.1) 278 | parser.add_argument("--epochs", type=int, default=1000) 279 | 280 | parser.add_argument("--patience", type=int, default=20) 281 | parser.add_argument("--cooldown", type=int, default=10) 282 | parser.add_argument("--early_stopping", type=int, default=100) 283 | parser.add_argument("--decay", type=float, default=0.5) 284 | parser.add_argument("--min_lr", type=float, default=5e-4) 285 | parser.add_argument("--polyak", type=float, default=0.998) 286 | 287 | parser.add_argument("--flows", type=int, default=5) 288 | parser.add_argument("--layers", type=int, default=1) 289 | parser.add_argument("--hidden_dim", type=int, default=10) 290 | parser.add_argument( 291 | "--residual", type=str, default="gated", choices=[None, "normal", "gated"] 292 | ) 293 | 294 | parser.add_argument("--expname", type=str, default="") 295 | parser.add_argument("--load", type=str, default=None) 296 | parser.add_argument("--save", action="store_true") 297 | parser.add_argument("--tensorboard", type=str, default="tensorboard") 298 | 299 | args = parser.parse_args() 300 | 301 | print("Arguments:") 302 | pprint.pprint(args.__dict__) 303 | 304 | args.path = os.path.join( 305 | "checkpoint", 306 | "{}{}_layers{}_h{}_flows{}{}_{}".format( 307 | args.expname + ("_" if args.expname != "" else ""), 308 | args.dataset, 309 | args.layers, 310 | args.hidden_dim, 311 | args.flows, 312 | "_" + args.residual if args.residual else "", 313 | str(datetime.datetime.now())[:-7].replace(" ", "-").replace(":", "-"), 314 | ), 315 | ) 316 | 317 | print("Loading dataset..") 318 | data_loader_train, data_loader_valid, data_loader_test = load_dataset(args) 319 | 320 | if args.save and not args.load: 321 | print("Creating directory experiment..") 322 | os.mkdir(args.path) 323 | with open(os.path.join(args.path, "args.json"), "w") as f: 324 | json.dump(args.__dict__, f, indent=4, sort_keys=True) 325 | 326 | print("Creating BNAF model..") 327 | model = create_model(args, verbose=True) 328 | 329 | print("Creating optimizer..") 330 | optimizer = Adam( 331 | model.parameters(), lr=args.learning_rate, amsgrad=True, polyak=args.polyak 332 | ) 333 | 334 | print("Creating scheduler..") 335 | scheduler = ReduceLROnPlateau( 336 | optimizer, 337 | factor=args.decay, 338 | patience=args.patience, 339 | cooldown=args.cooldown, 340 | min_lr=args.min_lr, 341 | verbose=True, 342 | early_stopping=args.early_stopping, 343 | 
threshold_mode="abs", 344 | ) 345 | 346 | args.start_epoch = 0 347 | if args.load: 348 | load_model(model, optimizer, args, load_start_epoch=True)() 349 | 350 | print("Training..") 351 | train( 352 | model, 353 | optimizer, 354 | scheduler, 355 | data_loader_train, 356 | data_loader_valid, 357 | data_loader_test, 358 | args, 359 | ) 360 | 361 | 362 | if __name__ == "__main__": 363 | main() 364 | -------------------------------------------------------------------------------- /download_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget https://zenodo.org/record/1161203/files/data.tar.gz 4 | tar -zxvf data.tar.gz 5 | rm -rf data/mnist/ 6 | rm -rf data/cifar10/ 7 | rm data.tar.gz 8 | mkdir checkpoint 9 | -------------------------------------------------------------------------------- /optim/adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class Adam(torch.optim.Optimizer): 6 | def __init__( 7 | self, 8 | params, 9 | lr=1e-3, 10 | betas=(0.9, 0.999), 11 | eps=1e-8, 12 | weight_decay=0, 13 | amsgrad=False, 14 | polyak=0.0, 15 | ): 16 | if not 0.0 <= lr: 17 | raise ValueError("Invalid learning rate: {}".format(lr)) 18 | if not 0.0 <= eps: 19 | raise ValueError("Invalid epsilon value: {}".format(eps)) 20 | if not 0.0 <= betas[0] < 1.0: 21 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 22 | if not 0.0 <= betas[1] < 1.0: 23 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 24 | if not 0.0 <= polyak <= 1.0: 25 | raise ValueError("Invalid polyak decay term: {}".format(polyak)) 26 | 27 | defaults = dict( 28 | lr=lr, 29 | betas=betas, 30 | eps=eps, 31 | weight_decay=weight_decay, 32 | amsgrad=amsgrad, 33 | polyak=polyak, 34 | ) 35 | super(Adam, self).__init__(params, defaults) 36 | 37 | def __setstate__(self, state): 38 | super(Adam, self).__setstate__(state) 39 | for group in self.param_groups: 40 | group.setdefault("amsgrad", False) 41 | 42 | def step(self, closure=None): 43 | """Performs a single optimization step. 44 | 45 | Arguments: 46 | closure (callable, optional): A closure that reevaluates the model 47 | and returns the loss. 48 | """ 49 | loss = None 50 | if closure is not None: 51 | loss = closure() 52 | 53 | for group in self.param_groups: 54 | for p in group["params"]: 55 | if p.grad is None: 56 | continue 57 | grad = p.grad.data 58 | if grad.is_sparse: 59 | raise RuntimeError( 60 | "Adam does not support sparse gradients, please consider SparseAdam instead" 61 | ) 62 | amsgrad = group["amsgrad"] 63 | 64 | state = self.state[p] 65 | 66 | # State initialization 67 | if len(state) == 0: 68 | state["step"] = 0 69 | # Exponential moving average of gradient values 70 | state["exp_avg"] = torch.zeros_like(p.data) 71 | # Exponential moving average of squared gradient values 72 | state["exp_avg_sq"] = torch.zeros_like(p.data) 73 | # Exponential moving average of param 74 | state["exp_avg_param"] = torch.zeros_like(p.data) 75 | if amsgrad: 76 | # Maintains max of all exp. moving avg. of sq. grad. 
values 77 | state["max_exp_avg_sq"] = torch.zeros_like(p.data) 78 | 79 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 80 | if amsgrad: 81 | max_exp_avg_sq = state["max_exp_avg_sq"] 82 | beta1, beta2 = group["betas"] 83 | 84 | state["step"] += 1 85 | 86 | if group["weight_decay"] != 0: 87 | grad.add_(group["weight_decay"], p.data) 88 | 89 | # Decay the first and second moment running average coefficient 90 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 91 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 92 | if amsgrad: 93 | # Maintains the maximum of all 2nd moment running avg. till now 94 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 95 | # Use the max. for normalizing running avg. of gradient 96 | denom = max_exp_avg_sq.sqrt().add_(group["eps"]) 97 | else: 98 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 99 | 100 | bias_correction1 = 1 - beta1 ** state["step"] 101 | bias_correction2 = 1 - beta2 ** state["step"] 102 | step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 103 | 104 | p.data.addcdiv_(-step_size, exp_avg, denom) 105 | 106 | polyak = self.defaults["polyak"] 107 | state["exp_avg_param"] = ( 108 | polyak * state["exp_avg_param"] + (1 - polyak) * p.data 109 | ) 110 | 111 | return loss 112 | 113 | def swap(self): 114 | """ 115 | Swapping the running average of params and the current params for saving parameters using polyak averaging 116 | """ 117 | for group in self.param_groups: 118 | for p in group["params"]: 119 | state = self.state[p] 120 | new = p.data 121 | p.data = state["exp_avg_param"] 122 | state["exp_avg_param"] = new 123 | 124 | def substitute(self): 125 | for group in self.param_groups: 126 | for p in group["params"]: 127 | p.data = self.state[p]["exp_avg_param"] 128 | -------------------------------------------------------------------------------- /optim/adamax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Adamax(torch.optim.Optimizer): 5 | def __init__( 6 | self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, polyak=0 7 | ): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 16 | if not 0.0 <= weight_decay: 17 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 18 | if not 0.0 <= polyak <= 1.0: 19 | raise ValueError("Invalid polyak decay term: {}".format(polyak)) 20 | 21 | defaults = dict( 22 | lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, polyak=polyak 23 | ) 24 | super(Adamax, self).__init__(params, defaults) 25 | 26 | def step(self, closure=None): 27 | 28 | loss = None 29 | if closure is not None: 30 | loss = closure() 31 | 32 | for group in self.param_groups: 33 | for p in group["params"]: 34 | if p.grad is None: 35 | continue 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError("Adamax does not support sparse gradients") 39 | state = self.state[p] 40 | 41 | # State initialization 42 | if len(state) == 0: 43 | state["step"] = 0 44 | state["exp_avg"] = torch.zeros_like(p.data) 45 | state["exp_inf"] = torch.zeros_like(p.data) 46 | # Exponential moving average of param 47 | state["exp_avg_param"] = torch.zeros_like(p.data) 48 | 
49 | exp_avg, exp_inf = state["exp_avg"], state["exp_inf"] 50 | beta1, beta2 = group["betas"] 51 | eps = group["eps"] 52 | 53 | state["step"] += 1 54 | 55 | if group["weight_decay"] != 0: 56 | grad = grad.add(group["weight_decay"], p.data) 57 | 58 | # Update biased first moment estimate. 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | # Update the exponentially weighted infinity norm. 61 | norm_buf = torch.cat( 62 | [ 63 | exp_inf.mul_(beta2).unsqueeze(0), 64 | grad.abs().add_(eps).unsqueeze_(0), 65 | ], 66 | 0, 67 | ) 68 | torch.max( 69 | norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()) 70 | ) 71 | 72 | bias_correction = 1 - beta1 ** state["step"] 73 | clr = group["lr"] / bias_correction 74 | 75 | p.data.addcdiv_(-clr, exp_avg, exp_inf) 76 | 77 | polyak = self.defaults["polyak"] 78 | state["exp_avg_param"] = ( 79 | polyak * state["exp_avg_param"] + (1 - polyak) * p.data 80 | ) 81 | 82 | return loss 83 | 84 | def swap(self): 85 | """ 86 | Swapping the running average of params and the current params for saving parameters using polyak averaging 87 | """ 88 | for group in self.param_groups: 89 | for p in group["params"]: 90 | state = self.state[p] 91 | new = p.data 92 | p.data = state["exp_avg_param"] 93 | state["exp_avg_param"] = new 94 | 95 | def substitute(self): 96 | for group in self.param_groups: 97 | for p in group["params"]: 98 | p.data = self.state[p]["exp_avg_param"] 99 | -------------------------------------------------------------------------------- /optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): 5 | def __init__(self, *args, early_stopping=None, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | self.early_stopping = early_stopping 8 | self.early_stopping_counter = 0 9 | 10 | def step(self, metrics, epoch=None, callback_best=None, callback_reduce=None): 11 | current = metrics 12 | if epoch is None: 13 | epoch = self.last_epoch = self.last_epoch + 1 14 | self.last_epoch = epoch 15 | 16 | if self.is_better(current, self.best): 17 | self.best = current 18 | self.num_bad_epochs = 0 19 | self.early_stopping_counter = 0 20 | if callback_best is not None: 21 | callback_best() 22 | else: 23 | self.num_bad_epochs += 1 24 | self.early_stopping_counter += 1 25 | 26 | if self.in_cooldown: 27 | self.cooldown_counter -= 1 28 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown 29 | 30 | if self.num_bad_epochs > self.patience: 31 | if callback_reduce is not None: 32 | callback_reduce() 33 | self._reduce_lr(epoch) 34 | self.cooldown_counter = self.cooldown 35 | self.num_bad_epochs = 0 36 | 37 | return self.early_stopping_counter == self.early_stopping 38 | -------------------------------------------------------------------------------- /toy2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pprint 5 | import datetime 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import torch 9 | from torch.utils import data 10 | from bnaf import * 11 | from tqdm import trange 12 | from data.generate2d import sample2d, energy2d 13 | 14 | 15 | def create_model(args, verbose=False): 16 | 17 | flows = [] 18 | for f in range(args.flows): 19 | layers = [] 20 | for _ in range(args.layers - 1): 21 | layers.append(MaskedWeight(2 * args.hidden_dim, 2 * args.hidden_dim, dim=2)) 22 | layers.append(Tanh()) 23 | 24 | flows.append( 
25 | BNAF( 26 | *( 27 | [MaskedWeight(2, 2 * args.hidden_dim, dim=2), Tanh()] 28 | + layers 29 | + [MaskedWeight(2 * args.hidden_dim, 2, dim=2)] 30 | ), 31 | res="gated" if f < args.flows - 1 else False 32 | ) 33 | ) 34 | 35 | if f < args.flows - 1: 36 | flows.append(Permutation(2, "flip")) 37 | 38 | model = Sequential(*flows).to(args.device) 39 | 40 | if verbose: 41 | print("{}".format(model)) 42 | print( 43 | "Parameters={}, n_dims={}".format( 44 | sum( 45 | (p != 0).sum() if len(p.shape) > 1 else torch.tensor(p.shape).item() 46 | for p in model.parameters() 47 | ), 48 | 2, 49 | ) 50 | ) 51 | 52 | return model 53 | 54 | 55 | def compute_log_p_x(model, x_mb): 56 | y_mb, log_diag_j_mb = model(x_mb) 57 | log_p_y_mb = ( 58 | torch.distributions.Normal(torch.zeros_like(y_mb), torch.ones_like(y_mb)) 59 | .log_prob(y_mb) 60 | .sum(-1) 61 | ) 62 | return log_p_y_mb + log_diag_j_mb 63 | 64 | 65 | def train_density2d(model, optimizer, scheduler, args): 66 | iterator = trange(args.steps, smoothing=0, dynamic_ncols=True) 67 | for epoch in iterator: 68 | 69 | x_mb = ( 70 | torch.from_numpy(sample2d(args.dataset, args.batch_dim)) 71 | .float() 72 | .to(args.device) 73 | ) 74 | 75 | loss = -compute_log_p_x(model, x_mb).mean() 76 | 77 | loss.backward() 78 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm) 79 | 80 | optimizer.step() 81 | optimizer.zero_grad() 82 | 83 | scheduler.step(loss) 84 | 85 | iterator.set_postfix( 86 | loss="{:.2f}".format(loss.data.cpu().numpy()), refresh=False 87 | ) 88 | 89 | 90 | def compute_kl(model, args): 91 | d_mb = torch.distributions.Normal( 92 | torch.zeros((args.batch_dim, 2)).to(args.device), 93 | torch.ones((args.batch_dim, 2)).to(args.device), 94 | ) 95 | y_mb = d_mb.sample() 96 | x_mb, log_diag_j_mb = model(y_mb) 97 | log_p_y_mb = d_mb.log_prob(y_mb).sum(-1) 98 | return ( 99 | log_p_y_mb 100 | - log_diag_j_mb 101 | + energy2d(args.dataset, x_mb) 102 | + (torch.relu(x_mb.abs() - 6) ** 2).sum(-1) 103 | ) 104 | 105 | 106 | def train_energy2d(model, optimizer, scheduler, args): 107 | iterator = trange(args.steps, smoothing=0, dynamic_ncols=True) 108 | for epoch in iterator: 109 | 110 | loss = compute_kl(model, args).mean() 111 | 112 | loss.backward() 113 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm) 114 | 115 | optimizer.step() 116 | optimizer.zero_grad() 117 | 118 | scheduler.step(loss) 119 | 120 | iterator.set_postfix( 121 | loss="{:.2f}".format(loss.data.cpu().numpy()), refresh=False 122 | ) 123 | 124 | 125 | def load(model, optimizer, path): 126 | print("Loading dataset..") 127 | checkpoint = torch.load(path) 128 | model.load_state_dict(checkpoint["model"]) 129 | optimizer.load_state_dict(checkpoint["optimizer"]) 130 | 131 | 132 | def save(model, optimizer, path): 133 | print("Saving model..") 134 | torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, path) 135 | 136 | 137 | def plot_density2d(model, args, limit=4, step=0.01): 138 | 139 | grid = torch.Tensor( 140 | [ 141 | [a, b] 142 | for a in np.arange(-limit, limit, step) 143 | for b in np.arange(-limit, limit, step) 144 | ] 145 | ) 146 | grid_dataset = torch.utils.data.TensorDataset(grid.to(args.device)) 147 | grid_data_loader = torch.utils.data.DataLoader( 148 | grid_dataset, batch_size=10000, shuffle=False 149 | ) 150 | 151 | prob = torch.cat( 152 | [ 153 | torch.exp(compute_log_p_x(model, x_mb)).detach() 154 | for x_mb, in grid_data_loader 155 | ], 156 | 0, 157 | ) 158 | 159 | prob = prob.view(int(2 * limit / step), int(2 * 
limit / step)).t() 160 | 161 | if args.reduce_extreme: 162 | prob = prob.clamp(max=prob.mean() + 3 * prob.std()) 163 | 164 | plt.figure(figsize=(12, 12)) 165 | plt.imshow(prob.cpu().data.numpy(), extent=(-limit, limit, -limit, limit)) 166 | plt.axis("off") 167 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1) 168 | 169 | if args.savefig: 170 | plt.savefig( 171 | os.path.join( 172 | args.load or args.path, "{}.jpg".format(datetime.datetime.now()) 173 | ) 174 | ) 175 | else: 176 | plt.show() 177 | 178 | 179 | def plot_energy2d(model, args, limit=4, step=0.05, resolution=(10000, 10000)): 180 | 181 | a = np.hstack( 182 | [ 183 | model(torch.randn(resolution[0], 2).to(args.device))[0] 184 | .t() 185 | .cpu() 186 | .data.numpy() 187 | for _ in trange(resolution[1]) 188 | ] 189 | ) 190 | 191 | H, _, _ = np.histogram2d( 192 | a[0], 193 | a[1], 194 | bins=(np.arange(-limit, limit, step), np.arange(-limit, limit, step)), 195 | ) 196 | 197 | plt.figure(figsize=(12, 12)) 198 | plt.imshow(H.T, interpolation="gaussian") 199 | plt.axis("off") 200 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1) 201 | 202 | if args.savefig: 203 | plt.savefig( 204 | os.path.join( 205 | args.load or args.path, "{}.jpg".format(datetime.datetime.now()) 206 | ) 207 | ) 208 | else: 209 | plt.show() 210 | 211 | 212 | def main(): 213 | parser = argparse.ArgumentParser() 214 | parser.add_argument("--device", type=str, default="cuda:0") 215 | parser.add_argument( 216 | "--dataset", 217 | type=str, 218 | default="8gaussians", 219 | choices=["8gaussians", "2spirals", "checkerboard", "t1", "t2", "t3", "t4"], 220 | ) 221 | parser.add_argument( 222 | "--experiment", type=str, default="density2d", choices=["density2d", "energy2d"] 223 | ) 224 | 225 | parser.add_argument("--learning_rate", type=float, default=1e-1) 226 | parser.add_argument("--batch_dim", type=int, default=200) 227 | parser.add_argument("--clip_norm", type=float, default=0.1) 228 | parser.add_argument("--steps", type=int, default=20000) 229 | 230 | parser.add_argument("--patience", type=int, default=2000) 231 | parser.add_argument("--decay", type=float, default=0.5) 232 | 233 | parser.add_argument("--flows", type=int, default=1) 234 | parser.add_argument("--layers", type=int, default=3) 235 | parser.add_argument("--hidden_dim", type=int, default=50) 236 | 237 | parser.add_argument("--expname", type=str, default="") 238 | parser.add_argument("--load", type=str, default=None) 239 | parser.add_argument("--save", action="store_true") 240 | parser.add_argument("--savefig", action="store_true") 241 | parser.add_argument("--reduce_extreme", action="store_true") 242 | 243 | args = parser.parse_args() 244 | 245 | print("Arguments:") 246 | pprint.pprint(args.__dict__) 247 | 248 | args.path = os.path.join( 249 | "checkpoint", 250 | "{}{}_layers{}_h{}_flows{}_{}".format( 251 | args.expname + ("_" if args.expname != "" else ""), 252 | args.dataset, 253 | args.layers, 254 | args.hidden_dim, 255 | args.flows, 256 | str(datetime.datetime.now())[:-7].replace(" ", "-").replace(":", "-"), 257 | ), 258 | ) 259 | 260 | if (args.save or args.savefig) and not args.load: 261 | print("Creating directory experiment..") 262 | os.mkdir(args.path) 263 | with open(os.path.join(args.path, "args.json"), "w") as f: 264 | json.dump(args.__dict__, f, indent=4, sort_keys=True) 265 | 266 | print("Creating BNAF model..") 267 | model = create_model(args, verbose=True) 268 | 269 | print("Creating optimizer..") 270 | optimizer = torch.optim.Adam( 271 | model.parameters(), lr=args.learning_rate, 
amsgrad=True 272 | ) 273 | 274 | print("Creating scheduler..") 275 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 276 | optimizer, 277 | factor=args.decay, 278 | patience=args.patience, 279 | min_lr=5e-4, 280 | verbose=True, 281 | threshold_mode="abs", 282 | ) 283 | 284 | if args.load: 285 | load(model, optimizer, os.path.join(args.load, "checkpoint.pt")) 286 | 287 | print("Training..") 288 | if args.experiment == "density2d": 289 | train_density2d(model, optimizer, scheduler, args) 290 | elif args.experiment == "energy2d": 291 | train_energy2d(model, optimizer, scheduler, args) 292 | 293 | if args.save: 294 | print("Saving..") 295 | save(model, optimizer, os.path.join(args.load or args.path, "checkpoint.pt")) 296 | 297 | print("Plotting..") 298 | if args.experiment == "density2d": 299 | plot_density2d(model, args) 300 | elif args.experiment == "energy2d": 301 | plot_energy2d(model, args) 302 | 303 | 304 | if __name__ == "__main__": 305 | main() 306 | --------------------------------------------------------------------------------
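The scripts above share one pattern: `MaskedWeight` and `Tanh` layers from `bnaf.py` are stacked inside a `BNAF` block, one or more blocks (interleaved with `Permutation` when `--flows > 1`) are chained with `Sequential`, and the log-density of a point is the standard-normal log-density of its image under the flow plus the accumulated log-det-Jacobian, which is what `compute_log_p_x` computes in both `toy2d.py` and `density_estimation.py`. The snippet below is a minimal sketch of that wiring for 2-D data and is not part of the repository files; it assumes `bnaf.py` is importable from the working directory, and the batch size of 16 and `hidden_dim = 50` are arbitrary illustrative values.

```python
import torch
from bnaf import BNAF, MaskedWeight, Sequential, Tanh

n_dims, hidden_dim = 2, 50  # 2-D toy data, 50 hidden units per input dimension

# A single BNAF block: input layer, one hidden layer, output layer,
# mirroring what create_model builds for --flows 1 --layers 2 --hidden_dim 50.
model = Sequential(
    BNAF(
        MaskedWeight(n_dims, n_dims * hidden_dim, dim=n_dims), Tanh(),
        MaskedWeight(n_dims * hidden_dim, n_dims * hidden_dim, dim=n_dims), Tanh(),
        MaskedWeight(n_dims * hidden_dim, n_dims, dim=n_dims),
    )
)

x = torch.randn(16, n_dims)   # a toy batch of 16 two-dimensional points
y, log_det = model(x)         # y = f(x) and log|det df/dx|, one value per point

# Change of variables: log p(x) = log N(f(x); 0, I) + log|det df/dx|
log_p_x = (
    torch.distributions.Normal(torch.zeros_like(y), torch.ones_like(y))
    .log_prob(y)
    .sum(-1)
    + log_det
)
print(log_p_x.shape)          # torch.Size([16])
```

Fitting a density estimator then amounts to minimizing `-log_p_x.mean()` over batches of real data, as done in `train_density2d` and `train`.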