├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 1-D.png ├── math.png └── theory.png ├── lib ├── datasets.py ├── implicit_flow.py ├── layers │ ├── __init__.py │ ├── act_norm.py │ ├── base │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── lipschitz.py │ │ ├── mixed_lipschitz.py │ │ └── utils.py │ ├── broyden.py │ ├── container.py │ ├── coupling.py │ ├── elemwise.py │ ├── glow.py │ ├── implicit_block.py │ ├── iresblock.py │ ├── mask_utils.py │ ├── normalization.py │ └── squeeze.py ├── lr_scheduler.py ├── optimizers.py ├── resflow.py ├── tabular.py ├── toy_data.py ├── utils.py └── visualize_flow.py ├── preprocessing ├── convert_to_pth.py ├── create_imagenet_benchmark_datasets.py └── extract_celeba_from_tfrecords.py ├── qualitative_samples.py ├── run_cifar10.sh ├── run_classification.sh ├── run_tabular.sh ├── run_toy.sh ├── train_classification.py ├── train_img.py ├── train_tabular.py └── train_toy.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *__pycache__* 3 | data/* 4 | pretrained_models 5 | experiments 6 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Cheng Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- /README.md: ---------------------------------------------------------------------------------------------------------------------------------------------------------------- 1 | # Implicit Normalizing Flows (ICLR 2021 Spotlight)[[arxiv](https://arxiv.org/abs/2103.09527)][[slides](https://docs.google.com/presentation/d/1reCHEBjy9ygJ0bM_jvItEGU1jHDn7NfYBeC7FBqK-1Y/edit?usp=sharing)] 2 | 3 | This repository contains a PyTorch implementation of the experiments from the paper [Implicit Normalizing Flows](https://arxiv.org/abs/2103.09527). The implementation is based on [Residual Flows](https://github.com/rtqichen/residual-flows). 4 | 5 |

6 | 7 |

8 | 9 | Implicit Normalizing Flows (ImpFlows) generalize normalizing flows by allowing the invertible mapping to be **implicitly** defined by the roots of an equation F(z, x) = 0. Building on [Residual Flows](https://arxiv.org/abs/1906.02735), we propose: 10 | 11 | + A **unique** and **invertible** mapping defined by an equation of the latent variable and the observed variable. 12 | + A **more powerful function space** than Residual Flows, obtained by relaxing the Lipschitz constraints of Residual Flows. 13 | + A **scalable algorithm** for generative modeling. 14 | 15 | As an example, in the 1-D case, the function space of ImpFlows contains all strongly monotonic differentiable functions, i.e. 16 |

17 | 18 |

19 | 20 | A 1-D function fitting example: 21 | 22 |

23 | 24 |

25 | 26 | ## Requirements 27 | 28 | - PyTorch 1.4, torchvision 0.5 29 | - Python 3.6+ 30 | 31 | ## Density Estimation Experiments 32 | 33 | ***NOTE***: By default, O(1)-memory gradients are enabled. However, the logged bits/dim during training will not be an actual estimate of bits/dim but whatever scalar was used to generate the unbiased gradients. If you want to check the actual bits/dim for training (and have sufficient GPU memory), set `--neumann-grad=False`. Note however that the memory cost can stochastically vary during training if this flag is `False`. 34 | 35 | Toy 2-D data: 36 | ``` 37 | bash run_toy.sh 38 | ``` 39 | 40 | CIFAR10: 41 | ``` 42 | bash run_cifar10.sh 43 | ``` 44 | 45 | Tabular Data: 46 | ``` 47 | bash run_tabular.sh 48 | ``` 49 | 50 | ## BibTeX 51 | ``` 52 | To be done after ICLR 2021. 53 | ``` 54 | -------------------------------------------------------------------------------- /assets/1-D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/implicit-normalizing-flows/d8babe48bdea7b1ba98be698bef2e954ac3817ee/assets/1-D.png -------------------------------------------------------------------------------- /assets/math.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/implicit-normalizing-flows/d8babe48bdea7b1ba98be698bef2e954ac3817ee/assets/math.png -------------------------------------------------------------------------------- /assets/theory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/implicit-normalizing-flows/d8babe48bdea7b1ba98be698bef2e954ac3817ee/assets/theory.png -------------------------------------------------------------------------------- /lib/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.datasets as vdsets 3 | 4 | 
5 | class Dataset(object): 6 | 7 | def __init__(self, loc, transform=None, in_mem=True): 8 | self.in_mem = in_mem 9 | self.dataset = torch.load(loc) 10 | if in_mem: self.dataset = self.dataset.float().div(255) 11 | self.transform = transform 12 | 13 | def __len__(self): 14 | return self.dataset.size(0) 15 | 16 | @property 17 | def ndim(self): 18 | return self.dataset.size(1) 19 | 20 | def __getitem__(self, index): 21 | x = self.dataset[index] 22 | if not self.in_mem: x = x.float().div(255) 23 | x = self.transform(x) if self.transform is not None else x 24 | return x, 0 25 | 26 | 27 | class MNIST(object): 28 | 29 | def __init__(self, dataroot, train=True, transform=None): 30 | self.mnist = vdsets.MNIST(dataroot, train=train, download=True, transform=transform) 31 | 32 | def __len__(self): 33 | return len(self.mnist) 34 | 35 | @property 36 | def ndim(self): 37 | return 1 38 | 39 | def __getitem__(self, index): 40 | return self.mnist[index] 41 | 42 | 43 | class CIFAR10(object): 44 | 45 | def __init__(self, dataroot, train=True, transform=None): 46 | self.cifar10 = vdsets.CIFAR10(dataroot, train=train, download=True, transform=transform) 47 | 48 | def __len__(self): 49 | return len(self.cifar10) 50 | 51 | @property 52 | def ndim(self): 53 | return 3 54 | 55 | def __getitem__(self, index): 56 | return self.cifar10[index] 57 | 58 | 59 | class CelebA5bit(object): 60 | 61 | LOC = 'data/celebahq64_5bit/celeba_full_64x64_5bit.pth' 62 | 63 | def __init__(self, train=True, transform=None): 64 | self.dataset = torch.load(self.LOC).float().div(31) 65 | if not train: 66 | self.dataset = self.dataset[:5000] 67 | self.transform = transform 68 | 69 | def __len__(self): 70 | return self.dataset.size(0) 71 | 72 | @property 73 | def ndim(self): 74 | return self.dataset.size(1) 75 | 76 | def __getitem__(self, index): 77 | x = self.dataset[index] 78 | x = self.transform(x) if self.transform is not None else x 79 | return x, 0 80 | 81 | 82 | class CelebAHQ(Dataset): 83 | TRAIN_LOC = 
'data/celebahq/celeba256_train.pth' 84 | TEST_LOC = 'data/celebahq/celeba256_validation.pth' 85 | 86 | def __init__(self, train=True, transform=None): 87 | return super(CelebAHQ, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform) 88 | 89 | 90 | class Imagenet32(Dataset): 91 | TRAIN_LOC = 'data/imagenet32/train_32x32.pth' 92 | TEST_LOC = 'data/imagenet32/valid_32x32.pth' 93 | 94 | def __init__(self, train=True, transform=None): 95 | return super(Imagenet32, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform) 96 | 97 | 98 | class Imagenet64(Dataset): 99 | TRAIN_LOC = 'data/imagenet64/train_64x64.pth' 100 | TEST_LOC = 'data/imagenet64/valid_64x64.pth' 101 | 102 | def __init__(self, train=True, transform=None): 103 | return super(Imagenet64, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform, in_mem=False) 104 | -------------------------------------------------------------------------------- /lib/implicit_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | import lib.layers as layers 6 | import lib.layers.base as base_layers 7 | 8 | ACT_FNS = { 9 | 'softplus': lambda b: nn.Softplus(), 10 | 'elu': lambda b: nn.ELU(inplace=b), 11 | 'swish': lambda b: base_layers.Swish(), 12 | 'lcube': lambda b: base_layers.LipschitzCube(), 13 | 'identity': lambda b: base_layers.Identity(), 14 | 'relu': lambda b: nn.ReLU(inplace=b), 15 | 'sin': lambda b: base_layers.Sin(), 16 | 'zero': lambda b: base_layers.Zero(), 17 | } 18 | 19 | 20 | class ImplicitFlow(nn.Module): 21 | 22 | def __init__( 23 | self, 24 | input_size, 25 | n_blocks=[16, 16], 26 | intermediate_dim=64, 27 | factor_out=True, 28 | quadratic=False, 29 | init_layer=None, 30 | actnorm=False, 31 | fc_actnorm=False, 32 | batchnorm=False, 33 | dropout=0, 34 | fc=False, 35 | coeff=0.9, 36 | vnorms='122f', 37 | n_lipschitz_iters=None, 38 | sn_atol=None, 39 | 
sn_rtol=None, 40 | n_power_series=5, 41 | n_dist='geometric', 42 | n_samples=1, 43 | kernels='3-1-3', 44 | activation_fn='elu', 45 | fc_end=True, 46 | fc_idim=128, 47 | n_exact_terms=0, 48 | preact=False, 49 | neumann_grad=True, 50 | grad_in_forward=False, 51 | first_resblock=True, 52 | learn_p=False, 53 | classification=False, 54 | classification_hdim=64, 55 | n_classes=10, 56 | ): 57 | super(ImplicitFlow, self).__init__() 58 | self.n_scale = min(len(n_blocks), self._calc_n_scale(input_size)) 59 | self.n_blocks = n_blocks 60 | self.intermediate_dim = intermediate_dim 61 | self.factor_out = factor_out 62 | self.quadratic = quadratic 63 | self.init_layer = init_layer 64 | self.actnorm = actnorm 65 | self.fc_actnorm = fc_actnorm 66 | self.batchnorm = batchnorm 67 | self.dropout = dropout 68 | self.fc = fc 69 | self.coeff = coeff 70 | self.vnorms = vnorms 71 | self.n_lipschitz_iters = n_lipschitz_iters 72 | self.sn_atol = sn_atol 73 | self.sn_rtol = sn_rtol 74 | self.n_power_series = n_power_series 75 | self.n_dist = n_dist 76 | self.n_samples = n_samples 77 | self.kernels = kernels 78 | self.activation_fn = activation_fn 79 | self.fc_end = fc_end 80 | self.fc_idim = fc_idim 81 | self.n_exact_terms = n_exact_terms 82 | self.preact = preact 83 | self.neumann_grad = neumann_grad 84 | self.grad_in_forward = grad_in_forward 85 | self.first_resblock = first_resblock 86 | self.learn_p = learn_p 87 | self.classification = classification 88 | self.classification_hdim = classification_hdim 89 | self.n_classes = n_classes 90 | 91 | if not self.n_scale > 0: 92 | raise ValueError('Could not compute number of scales for input of' 'size (%d,%d,%d,%d)' % input_size) 93 | 94 | self.transforms = self._build_net(input_size) 95 | 96 | self.dims = [o[1:] for o in self.calc_output_size(input_size)] 97 | 98 | if self.classification: 99 | self.build_multiscale_classifier(input_size) 100 | 101 | def _build_net(self, input_size): 102 | _, c, h, w = input_size 103 | transforms = [] 104 | for i 
in range(self.n_scale): 105 | transforms.append( 106 | StackedImplicitBlocks( 107 | initial_size=(c, h, w), 108 | idim=self.intermediate_dim, 109 | squeeze=(i < self.n_scale - 1), # don't squeeze last layer 110 | init_layer=self.init_layer if i == 0 else None, 111 | n_blocks=self.n_blocks[i], 112 | quadratic=self.quadratic, 113 | actnorm=self.actnorm, 114 | fc_actnorm=self.fc_actnorm, 115 | batchnorm=self.batchnorm, 116 | dropout=self.dropout, 117 | fc=self.fc, 118 | coeff=self.coeff, 119 | vnorms=self.vnorms, 120 | n_lipschitz_iters=self.n_lipschitz_iters, 121 | sn_atol=self.sn_atol, 122 | sn_rtol=self.sn_rtol, 123 | n_power_series=self.n_power_series, 124 | n_dist=self.n_dist, 125 | n_samples=self.n_samples, 126 | kernels=self.kernels, 127 | activation_fn=self.activation_fn, 128 | fc_end=self.fc_end, 129 | fc_idim=self.fc_idim, 130 | n_exact_terms=self.n_exact_terms, 131 | preact=self.preact, 132 | neumann_grad=self.neumann_grad, 133 | grad_in_forward=self.grad_in_forward, 134 | first_resblock=self.first_resblock and (i == 0), 135 | learn_p=self.learn_p, 136 | ) 137 | ) 138 | c, h, w = c * 2 if self.factor_out else c * 4, h // 2, w // 2 139 | return nn.ModuleList(transforms) 140 | 141 | def _calc_n_scale(self, input_size): 142 | _, _, h, w = input_size 143 | n_scale = 0 144 | while h >= 4 and w >= 4: 145 | n_scale += 1 146 | h = h // 2 147 | w = w // 2 148 | return n_scale 149 | 150 | def calc_output_size(self, input_size): 151 | n, c, h, w = input_size 152 | if not self.factor_out: 153 | k = self.n_scale - 1 154 | return [[n, c * 4**k, h // 2**k, w // 2**k]] 155 | output_sizes = [] 156 | for i in range(self.n_scale): 157 | if i < self.n_scale - 1: 158 | c *= 2 159 | h //= 2 160 | w //= 2 161 | output_sizes.append((n, c, h, w)) 162 | else: 163 | output_sizes.append((n, c, h, w)) 164 | return tuple(output_sizes) 165 | 166 | def build_multiscale_classifier(self, input_size): 167 | n, c, h, w = input_size 168 | hidden_shapes = [] 169 | for i in range(self.n_scale): 
170 | if i < self.n_scale - 1: 171 | c *= 2 if self.factor_out else 4 172 | h //= 2 173 | w //= 2 174 | hidden_shapes.append((n, c, h, w)) 175 | 176 | classification_heads = [] 177 | for i, hshape in enumerate(hidden_shapes): 178 | classification_heads.append( 179 | nn.Sequential( 180 | nn.Conv2d(hshape[1], self.classification_hdim, 3, 1, 1), 181 | layers.ActNorm2d(self.classification_hdim), 182 | nn.ReLU(inplace=True), 183 | nn.AdaptiveAvgPool2d((1, 1)), 184 | ) 185 | ) 186 | self.classification_heads = nn.ModuleList(classification_heads) 187 | self.logit_layer = nn.Linear(self.classification_hdim * len(classification_heads), self.n_classes) 188 | 189 | def forward(self, x, logpx=None, inverse=False, classify=False, restore=False): 190 | if inverse: 191 | return self.inverse(x, logpx) 192 | out = [] 193 | if classify: class_outs = [] 194 | for idx in range(len(self.transforms)): 195 | if logpx is not None: 196 | x, logpx = self.transforms[idx].forward(x, logpx, restore=restore) 197 | else: 198 | x = self.transforms[idx].forward(x, restore=restore) 199 | if self.factor_out and (idx < len(self.transforms) - 1): 200 | d = x.size(1) // 2 201 | x, f = x[:, :d], x[:, d:] 202 | out.append(f) 203 | 204 | # Handle classification. 
205 | if classify: 206 | if self.factor_out: 207 | class_outs.append(self.classification_heads[idx](f)) 208 | else: 209 | class_outs.append(self.classification_heads[idx](x)) 210 | 211 | out.append(x) 212 | out = torch.cat([o.view(o.size()[0], -1) for o in out], 1) 213 | output = out if logpx is None else (out, logpx) 214 | if classify: 215 | h = torch.cat(class_outs, dim=1).squeeze(-1).squeeze(-1) 216 | logits = self.logit_layer(h) 217 | return output, logits 218 | else: 219 | return output 220 | 221 | def inverse(self, z, logpz=None): 222 | if self.factor_out: 223 | z = z.view(z.shape[0], -1) 224 | zs = [] 225 | i = 0 226 | for dims in self.dims: 227 | s = np.prod(dims) 228 | zs.append(z[:, i:i + s]) 229 | i += s 230 | zs = [_z.view(_z.size()[0], *zsize) for _z, zsize in zip(zs, self.dims)] 231 | 232 | if logpz is None: 233 | z_prev = self.transforms[-1].inverse(zs[-1]) 234 | for idx in range(len(self.transforms) - 2, -1, -1): 235 | z_prev = torch.cat((z_prev, zs[idx]), dim=1) 236 | z_prev = self.transforms[idx].inverse(z_prev) 237 | return z_prev 238 | else: 239 | z_prev, logpz = self.transforms[-1].inverse(zs[-1], logpz) 240 | for idx in range(len(self.transforms) - 2, -1, -1): 241 | z_prev = torch.cat((z_prev, zs[idx]), dim=1) 242 | z_prev, logpz = self.transforms[idx].inverse(z_prev, logpz) 243 | return z_prev, logpz 244 | else: 245 | z = z.view(z.shape[0], *self.dims[-1]) 246 | for idx in range(len(self.transforms) - 1, -1, -1): 247 | if logpz is None: 248 | z = self.transforms[idx].inverse(z) 249 | else: 250 | z, logpz = self.transforms[idx].inverse(z, logpz) 251 | return z if logpz is None else (z, logpz) 252 | 253 | 254 | class StackedImplicitBlocks(layers.SequentialFlow): 255 | 256 | def __init__( 257 | self, 258 | initial_size, 259 | idim, 260 | squeeze=True, 261 | init_layer=None, 262 | n_blocks=1, 263 | quadratic=False, 264 | actnorm=False, 265 | fc_actnorm=False, 266 | batchnorm=False, 267 | dropout=0, 268 | fc=False, 269 | coeff=0.9, 270 | 
vnorms='122f', 271 | n_lipschitz_iters=None, 272 | sn_atol=None, 273 | sn_rtol=None, 274 | n_power_series=5, 275 | n_dist='geometric', 276 | n_samples=1, 277 | kernels='3-1-3', 278 | activation_fn='elu', 279 | fc_end=True, 280 | fc_nblocks=2, 281 | fc_idim=128, 282 | n_exact_terms=0, 283 | preact=False, 284 | neumann_grad=True, 285 | grad_in_forward=False, 286 | first_resblock=True, 287 | learn_p=False, 288 | ): 289 | 290 | chain = [] 291 | 292 | # Parse vnorms 293 | ps = [] 294 | for p in vnorms: 295 | if p == 'f': 296 | ps.append(float('inf')) 297 | else: 298 | ps.append(float(p)) 299 | domains, codomains = ps[:-1], ps[1:] 300 | assert len(domains) == len(kernels.split('-')) 301 | 302 | def _actnorm(size, fc): 303 | if fc: 304 | return FCWrapper(layers.ActNorm1d(size[0] * size[1] * size[2])) 305 | else: 306 | return layers.ActNorm2d(size[0]) 307 | 308 | def _quadratic_layer(initial_size, fc): 309 | if fc: 310 | c, h, w = initial_size 311 | dim = c * h * w 312 | return FCWrapper(layers.InvertibleLinear(dim)) 313 | else: 314 | return layers.InvertibleConv2d(initial_size[0]) 315 | 316 | def _lipschitz_layer(fc): 317 | return base_layers.get_linear if fc else base_layers.get_conv2d 318 | 319 | def _resblock(initial_size, fc, idim=idim, first_resblock=True): 320 | if fc: 321 | return layers.imBlock( 322 | FCNet( 323 | input_shape=initial_size, 324 | idim=idim, 325 | lipschitz_layer=_lipschitz_layer(True), 326 | nhidden=len(kernels.split('-')) - 1, 327 | coeff=coeff, 328 | domains=domains, 329 | codomains=codomains, 330 | n_iterations=n_lipschitz_iters, 331 | activation_fn=activation_fn, 332 | preact=preact, 333 | dropout=dropout, 334 | sn_atol=sn_atol, 335 | sn_rtol=sn_rtol, 336 | learn_p=learn_p, 337 | ), 338 | FCNet( 339 | input_shape=initial_size, 340 | idim=idim, 341 | lipschitz_layer=_lipschitz_layer(True), 342 | nhidden=len(kernels.split('-')) - 1, 343 | coeff=coeff, 344 | domains=domains, 345 | codomains=codomains, 346 | n_iterations=n_lipschitz_iters, 347 | 
activation_fn=activation_fn, 348 | preact=preact, 349 | dropout=dropout, 350 | sn_atol=sn_atol, 351 | sn_rtol=sn_rtol, 352 | learn_p=learn_p, 353 | ), 354 | n_power_series=n_power_series, 355 | n_dist=n_dist, 356 | n_samples=n_samples, 357 | n_exact_terms=n_exact_terms, 358 | neumann_grad=neumann_grad, 359 | grad_in_forward=grad_in_forward, 360 | ) 361 | else: 362 | def build_nnet(): 363 | ks = list(map(int, kernels.split('-'))) 364 | if learn_p: 365 | _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))] 366 | _codomains = _domains[1:] + [_domains[0]] 367 | else: 368 | _domains = domains 369 | _codomains = codomains 370 | nnet = [] 371 | if not first_resblock and preact: 372 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0])) 373 | nnet.append(ACT_FNS[activation_fn](False)) 374 | nnet.append( 375 | _lipschitz_layer(fc)( 376 | initial_size[0], idim, ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters, 377 | domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol 378 | ) 379 | ) 380 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim)) 381 | nnet.append(ACT_FNS[activation_fn](True)) 382 | for i, k in enumerate(ks[1:-1]): 383 | nnet.append( 384 | _lipschitz_layer(fc)( 385 | idim, idim, k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters, 386 | domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol 387 | ) 388 | ) 389 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim)) 390 | nnet.append(ACT_FNS[activation_fn](True)) 391 | if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True)) 392 | nnet.append( 393 | _lipschitz_layer(fc)( 394 | idim, initial_size[0], ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters, 395 | domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol 396 | ) 397 | ) 398 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0])) 399 | return nn.Sequential(*nnet) 400 | return layers.imBlock( 401 | 
build_nnet(), 402 | build_nnet(), 403 | n_power_series=n_power_series, 404 | n_dist=n_dist, 405 | n_samples=n_samples, 406 | n_exact_terms=n_exact_terms, 407 | neumann_grad=neumann_grad, 408 | grad_in_forward=grad_in_forward, 409 | ) 410 | 411 | if init_layer is not None: chain.append(init_layer) 412 | if first_resblock and actnorm: chain.append(_actnorm(initial_size, fc)) 413 | if first_resblock and fc_actnorm: chain.append(_actnorm(initial_size, True)) 414 | 415 | if squeeze: 416 | c, h, w = initial_size 417 | for i in range(n_blocks): 418 | if quadratic: chain.append(_quadratic_layer(initial_size, fc)) 419 | chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0))) 420 | if actnorm: chain.append(_actnorm(initial_size, fc)) 421 | if fc_actnorm: chain.append(_actnorm(initial_size, True)) 422 | chain.append(layers.SqueezeLayer(2)) 423 | else: 424 | for i in range(n_blocks): 425 | if quadratic: chain.append(_quadratic_layer(initial_size, fc)) 426 | chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0))) 427 | if actnorm: chain.append(_actnorm(initial_size, fc)) 428 | if fc_actnorm: chain.append(_actnorm(initial_size, True)) 429 | # Use four fully connected layers at the end. 
430 | if fc_end: 431 | for _ in range(fc_nblocks): 432 | chain.append(_resblock(initial_size, True, fc_idim)) 433 | if actnorm or fc_actnorm: chain.append(_actnorm(initial_size, True)) 434 | super(StackedImplicitBlocks, self).__init__(chain) 435 | 436 | 437 | class FCNet(nn.Module): 438 | 439 | def __init__( 440 | self, input_shape, idim, lipschitz_layer, nhidden, coeff, domains, codomains, n_iterations, activation_fn, 441 | preact, dropout, sn_atol, sn_rtol, learn_p, div_in=1 442 | ): 443 | super(FCNet, self).__init__() 444 | self.input_shape = input_shape 445 | c, h, w = self.input_shape 446 | dim = c * h * w 447 | nnet = [] 448 | last_dim = dim // div_in 449 | if preact: nnet.append(ACT_FNS[activation_fn](False)) 450 | if learn_p: 451 | domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(domains))] 452 | codomains = domains[1:] + [domains[0]] 453 | for i in range(nhidden): 454 | nnet.append( 455 | lipschitz_layer(last_dim, idim) if lipschitz_layer == nn.Linear else lipschitz_layer( 456 | last_dim, idim, coeff=coeff, n_iterations=n_iterations, domain=domains[i], codomain=codomains[i], 457 | atol=sn_atol, rtol=sn_rtol 458 | ) 459 | ) 460 | nnet.append(ACT_FNS[activation_fn](True)) 461 | last_dim = idim 462 | if dropout: nnet.append(nn.Dropout(dropout, inplace=True)) 463 | nnet.append( 464 | lipschitz_layer(last_dim, dim) if lipschitz_layer == nn.Linear else lipschitz_layer( 465 | last_dim, dim, coeff=coeff, n_iterations=n_iterations, domain=domains[-1], codomain=codomains[-1], 466 | atol=sn_atol, rtol=sn_rtol 467 | ) 468 | ) 469 | self.nnet = nn.Sequential(*nnet) 470 | 471 | def forward(self, x, restore=False): 472 | x = x.view(x.shape[0], -1) 473 | y = self.nnet(x) 474 | return y.view(y.shape[0], *self.input_shape) 475 | 476 | 477 | class FCWrapper(nn.Module): 478 | 479 | def __init__(self, fc_module): 480 | super(FCWrapper, self).__init__() 481 | self.fc_module = fc_module 482 | 483 | def forward(self, x, logpx=None, restore=False): 484 | shape = 
x.shape 485 | x = x.view(x.shape[0], -1) 486 | if logpx is None: 487 | y = self.fc_module(x) 488 | return y.view(*shape) 489 | else: 490 | y, logpy = self.fc_module(x, logpx) 491 | return y.view(*shape), logpy 492 | 493 | def inverse(self, y, logpy=None): 494 | shape = y.shape 495 | y = y.view(y.shape[0], -1) 496 | if logpy is None: 497 | x = self.fc_module.inverse(y) 498 | return x.view(*shape) 499 | else: 500 | x, logpx = self.fc_module.inverse(y, logpy) 501 | return x.view(*shape), logpx 502 | -------------------------------------------------------------------------------- /lib/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .act_norm import * 2 | from .container import * 3 | from .coupling import * 4 | from .elemwise import * 5 | from .iresblock import * 6 | from .implicit_block import * 7 | from .normalization import * 8 | from .squeeze import * 9 | from .glow import * 10 | -------------------------------------------------------------------------------- /lib/layers/act_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | 5 | __all__ = ['ActNorm1d', 'ActNorm2d'] 6 | 7 | 8 | class ActNormNd(nn.Module): 9 | 10 | def __init__(self, num_features, eps=1e-12): 11 | super(ActNormNd, self).__init__() 12 | self.num_features = num_features 13 | self.eps = eps 14 | self.weight = Parameter(torch.Tensor(num_features)) 15 | self.bias = Parameter(torch.Tensor(num_features)) 16 | self.register_buffer('initialized', torch.tensor(0)) 17 | 18 | @property 19 | def shape(self): 20 | raise NotImplementedError 21 | 22 | def forward(self, x, logpx=None, restore=None): 23 | c = x.size(1) 24 | 25 | if not self.initialized: 26 | with torch.no_grad(): 27 | # compute batch statistics 28 | x_t = x.transpose(0, 1).contiguous().view(c, -1) 29 | batch_mean = torch.mean(x_t, dim=1) 30 | batch_var = torch.var(x_t, 
dim=1) 31 | 32 | # for numerical issues 33 | batch_var = torch.max(batch_var, torch.tensor(0.2).to(batch_var)) 34 | 35 | self.bias.data.copy_(-batch_mean) 36 | self.weight.data.copy_(-0.5 * torch.log(batch_var)) 37 | self.initialized.fill_(1) 38 | 39 | bias = self.bias.view(*self.shape).expand_as(x) 40 | weight = self.weight.view(*self.shape).expand_as(x) 41 | 42 | y = (x + bias) * torch.exp(weight) 43 | 44 | if logpx is None: 45 | return y 46 | else: 47 | return y, logpx - self._logdetgrad(x) 48 | 49 | def inverse(self, y, logpy=None): 50 | assert self.initialized 51 | bias = self.bias.view(*self.shape).expand_as(y) 52 | weight = self.weight.view(*self.shape).expand_as(y) 53 | 54 | x = y * torch.exp(-weight) - bias 55 | 56 | if logpy is None: 57 | return x 58 | else: 59 | return x, logpy + self._logdetgrad(x) 60 | 61 | def _logdetgrad(self, x): 62 | return self.weight.view(*self.shape).expand(*x.size()).contiguous().view(x.size(0), -1).sum(1, keepdim=True) 63 | 64 | def __repr__(self): 65 | return ('{name}({num_features})'.format(name=self.__class__.__name__, **self.__dict__)) 66 | 67 | 68 | class ActNorm1d(ActNormNd): 69 | 70 | @property 71 | def shape(self): 72 | return [1, -1] 73 | 74 | 75 | class ActNorm2d(ActNormNd): 76 | 77 | @property 78 | def shape(self): 79 | return [1, -1, 1, 1] 80 | -------------------------------------------------------------------------------- /lib/layers/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .lipschitz import * 3 | from .mixed_lipschitz import * 4 | -------------------------------------------------------------------------------- /lib/layers/base/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class Sin(nn.Module): 8 | def __init__(self): 9 | super(Sin, self).__init__() 10 | 11 | 
def forward(self, x): 12 | return torch.sin(2. * math.pi * x) / math.pi * 0.5 13 | 14 | 15 | class Identity(nn.Module): 16 | 17 | def forward(self, x): 18 | return x 19 | 20 | class Zero(nn.Module): 21 | 22 | def forward(self, x): 23 | return torch.zeros_like(x).to(x) 24 | 25 | class FullSort(nn.Module): 26 | 27 | def forward(self, x): 28 | return torch.sort(x, 1)[0] 29 | 30 | 31 | class MaxMin(nn.Module): 32 | 33 | def forward(self, x): 34 | b, d = x.shape 35 | max_vals = torch.max(x.view(b, d // 2, 2), 2)[0] 36 | min_vals = torch.min(x.view(b, d // 2, 2), 2)[0] 37 | return torch.cat([max_vals, min_vals], 1) 38 | 39 | 40 | class LipschitzCube(nn.Module): 41 | 42 | def forward(self, x): 43 | return (x >= 1).to(x) * (x - 2 / 3) + (x <= -1).to(x) * (x + 2 / 3) + ((x > -1) * (x < 1)).to(x) * x**3 / 3 44 | 45 | 46 | class SwishFn(torch.autograd.Function): 47 | 48 | @staticmethod 49 | def forward(ctx, x, beta): 50 | beta_sigm = torch.sigmoid(beta * x) 51 | output = x * beta_sigm 52 | ctx.save_for_backward(x, output, beta) 53 | return output / 1.1 54 | 55 | @staticmethod 56 | def backward(ctx, grad_output): 57 | x, output, beta = ctx.saved_tensors 58 | beta_sigm = output / x 59 | grad_x = grad_output * (beta * output + beta_sigm * (1 - beta * output)) 60 | grad_beta = torch.sum(grad_output * (x * output - output * output)).expand_as(beta) 61 | return grad_x / 1.1, grad_beta / 1.1 62 | 63 | 64 | class Swish(nn.Module): 65 | 66 | def __init__(self): 67 | super(Swish, self).__init__() 68 | self.beta = nn.Parameter(torch.tensor([0.5])) 69 | 70 | def forward(self, x): 71 | return (x * torch.sigmoid_(x * F.softplus(self.beta))).div_(1.1) 72 | 73 | 74 | if __name__ == '__main__': 75 | 76 | m = Swish() 77 | xx = torch.linspace(-5, 5, 1000).requires_grad_(True) 78 | yy = m(xx) 79 | dd, dbeta = torch.autograd.grad(yy.sum() * 2, [xx, m.beta]) 80 | 81 | import matplotlib.pyplot as plt 82 | 83 | plt.plot(xx.detach().numpy(), yy.detach().numpy(), label='Func') 84 | 
plt.plot(xx.detach().numpy(), dd.detach().numpy(), label='Deriv') 85 | plt.plot(xx.detach().numpy(), torch.max(dd.detach().abs() - 1, torch.zeros_like(dd)).numpy(), label='|Deriv| > 1') 86 | plt.legend() 87 | plt.tight_layout() 88 | plt.show() 89 | -------------------------------------------------------------------------------- /lib/layers/base/mixed_lipschitz.py: -------------------------------------------------------------------------------- 1 | from torch._six import container_abcs 2 | from itertools import repeat 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import torch.nn.functional as F 8 | 9 | __all__ = ['InducedNormLinear', 'InducedNormConv2d'] 10 | 11 | 12 | class InducedNormLinear(nn.Module): 13 | 14 | def __init__( 15 | self, in_features, out_features, bias=True, coeff=0.97, domain=2, codomain=2, n_iterations=None, atol=None, 16 | rtol=None, zero_init=False, **unused_kwargs 17 | ): 18 | del unused_kwargs 19 | super(InducedNormLinear, self).__init__() 20 | self.in_features = in_features 21 | self.out_features = out_features 22 | self.coeff = coeff 23 | self.n_iterations = n_iterations 24 | self.atol = atol 25 | self.rtol = rtol 26 | self.domain = domain 27 | self.codomain = codomain 28 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 29 | if bias: 30 | self.bias = nn.Parameter(torch.Tensor(out_features)) 31 | else: 32 | self.register_parameter('bias', None) 33 | self.reset_parameters(zero_init) 34 | 35 | with torch.no_grad(): 36 | domain, codomain = self.compute_domain_codomain() 37 | 38 | h, w = self.weight.shape 39 | self.register_buffer('scale', torch.tensor(0.)) 40 | self.register_buffer('u', normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain)) 41 | self.register_buffer('v', normalize_v(self.weight.new_empty(w).normal_(0, 1), domain)) 42 | 43 | # Try different random seeds to find the best u and v. 
44 | with torch.no_grad(): 45 | self.compute_weight(True, n_iterations=200, atol=None, rtol=None) 46 | best_scale = self.scale.clone() 47 | best_u, best_v = self.u.clone(), self.v.clone() 48 | if not (domain == 2 and codomain == 2): 49 | for _ in range(10): 50 | self.register_buffer('u', normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain)) 51 | self.register_buffer('v', normalize_v(self.weight.new_empty(w).normal_(0, 1), domain)) 52 | self.compute_weight(True, n_iterations=200) 53 | if self.scale > best_scale: 54 | best_u, best_v = self.u.clone(), self.v.clone() 55 | self.u.copy_(best_u) 56 | self.v.copy_(best_v) 57 | 58 | def reset_parameters(self, zero_init=False): 59 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 60 | if zero_init: 61 | # normalize cannot handle zero weight in some cases. 62 | self.weight.data.div_(1000) 63 | if self.bias is not None: 64 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 65 | bound = 1 / math.sqrt(fan_in) 66 | init.uniform_(self.bias, -bound, bound) 67 | 68 | def compute_domain_codomain(self): 69 | if torch.is_tensor(self.domain): 70 | domain = asym_squash(self.domain) 71 | codomain = asym_squash(self.codomain) 72 | else: 73 | domain, codomain = self.domain, self.codomain 74 | return domain, codomain 75 | 76 | def compute_one_iter(self): 77 | domain, codomain = self.compute_domain_codomain() 78 | u = self.u.detach() 79 | v = self.v.detach() 80 | weight = self.weight.detach() 81 | u = normalize_u(torch.mv(weight, v), codomain) 82 | v = normalize_v(torch.mv(weight.t(), u), domain) 83 | return torch.dot(u, torch.mv(weight, v)) 84 | 85 | def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None): 86 | u = self.u 87 | v = self.v 88 | weight = self.weight 89 | 90 | if update: 91 | 92 | n_iterations = self.n_iterations if n_iterations is None else n_iterations 93 | atol = self.atol if atol is None else atol 94 | rtol = self.rtol if rtol is None else atol 95 | 96 | if n_iterations is None 
and (atol is None or rtol is None): 97 | raise ValueError('Need one of n_iteration or (atol, rtol).') 98 | 99 | max_itrs = 200 100 | if n_iterations is not None: 101 | max_itrs = n_iterations 102 | 103 | with torch.no_grad(): 104 | domain, codomain = self.compute_domain_codomain() 105 | for _ in range(max_itrs): 106 | # Algorithm from http://www.qetlab.com/InducedMatrixNorm. 107 | if n_iterations is None and atol is not None and rtol is not None: 108 | old_v = v.clone() 109 | old_u = u.clone() 110 | 111 | u = normalize_u(torch.mv(weight, v), codomain, out=u) 112 | v = normalize_v(torch.mv(weight.t(), u), domain, out=v) 113 | 114 | if n_iterations is None and atol is not None and rtol is not None: 115 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5) 116 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5) 117 | tol_u = atol + rtol * torch.max(u) 118 | tol_v = atol + rtol * torch.max(v) 119 | if err_u < tol_u and err_v < tol_v: 120 | break 121 | self.v.copy_(v) 122 | self.u.copy_(u) 123 | u = u.clone() 124 | v = v.clone() 125 | 126 | sigma = torch.dot(u, torch.mv(weight, v)) 127 | with torch.no_grad(): 128 | self.scale.copy_(sigma) 129 | # soft normalization: only when sigma larger than coeff 130 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff) 131 | weight = weight / factor 132 | return weight 133 | 134 | def forward(self, input): 135 | weight = self.compute_weight(update=False) 136 | return F.linear(input, weight, self.bias) 137 | 138 | def extra_repr(self): 139 | domain, codomain = self.compute_domain_codomain() 140 | return ( 141 | 'in_features={}, out_features={}, bias={}' 142 | ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format( 143 | self.in_features, self.out_features, self.bias is not None, self.coeff, domain, codomain, 144 | self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain) 145 | ) 146 | ) 147 | 148 | 149 | class InducedNormConv2d(nn.Module): 150 | 151 | def 
__init__( 152 | self, in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, domain=2, codomain=2, 153 | n_iterations=None, atol=None, rtol=None, **unused_kwargs 154 | ): 155 | del unused_kwargs 156 | super(InducedNormConv2d, self).__init__() 157 | self.in_channels = in_channels 158 | self.out_channels = out_channels 159 | self.kernel_size = _pair(kernel_size) 160 | self.stride = _pair(stride) 161 | self.padding = _pair(padding) 162 | self.coeff = coeff 163 | self.n_iterations = n_iterations 164 | self.domain = domain 165 | self.codomain = codomain 166 | self.atol = atol 167 | self.rtol = rtol 168 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) 169 | if bias: 170 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 171 | else: 172 | self.register_parameter('bias', None) 173 | self.reset_parameters() 174 | self.register_buffer('initialized', torch.tensor(0)) 175 | self.register_buffer('spatial_dims', torch.tensor([1., 1.])) 176 | self.register_buffer('scale', torch.tensor(0.)) 177 | self.register_buffer('u', self.weight.new_empty(self.out_channels)) 178 | self.register_buffer('v', self.weight.new_empty(self.in_channels)) 179 | 180 | def compute_domain_codomain(self): 181 | if torch.is_tensor(self.domain): 182 | domain = asym_squash(self.domain) 183 | codomain = asym_squash(self.codomain) 184 | else: 185 | domain, codomain = self.domain, self.codomain 186 | return domain, codomain 187 | 188 | def reset_parameters(self): 189 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 190 | if self.bias is not None: 191 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 192 | bound = 1 / math.sqrt(fan_in) 193 | init.uniform_(self.bias, -bound, bound) 194 | 195 | def _initialize_u_v(self): 196 | with torch.no_grad(): 197 | domain, codomain = self.compute_domain_codomain() 198 | if self.kernel_size == (1, 1): 199 | self.u.resize_(self.out_channels).normal_(0, 1) 200 | self.u.copy_(normalize_u(self.u, 
codomain)) 201 | self.v.resize_(self.in_channels).normal_(0, 1) 202 | self.v.copy_(normalize_v(self.v, domain)) 203 | else: 204 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item()) 205 | with torch.no_grad(): 206 | num_input_dim = c * h * w 207 | self.v.resize_(num_input_dim).normal_(0, 1) 208 | self.v.copy_(normalize_v(self.v, domain)) 209 | # forward call to infer the shape 210 | u = F.conv2d( 211 | self.v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None 212 | ) 213 | num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3] 214 | # overwrite u with random init 215 | self.u.resize_(num_output_dim).normal_(0, 1) 216 | self.u.copy_(normalize_u(self.u, codomain)) 217 | 218 | self.initialized.fill_(1) 219 | 220 | # Try different random seeds to find the best u and v. 221 | self.compute_weight(True) 222 | best_scale = self.scale.clone() 223 | best_u, best_v = self.u.clone(), self.v.clone() 224 | if not (domain == 2 and codomain == 2): 225 | for _ in range(10): 226 | if self.kernel_size == (1, 1): 227 | self.u.copy_(normalize_u(self.weight.new_empty(self.out_channels).normal_(0, 1), codomain)) 228 | self.v.copy_(normalize_v(self.weight.new_empty(self.in_channels).normal_(0, 1), domain)) 229 | else: 230 | self.u.copy_(normalize_u(torch.randn(num_output_dim).to(self.weight), codomain)) 231 | self.v.copy_(normalize_v(torch.randn(num_input_dim).to(self.weight), domain)) 232 | self.compute_weight(True, n_iterations=200) 233 | if self.scale > best_scale: 234 | best_u, best_v = self.u.clone(), self.v.clone() 235 | self.u.copy_(best_u) 236 | self.v.copy_(best_v) 237 | # These two lines are important, see https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm 238 | self.u = self.u.clone(memory_format=torch.contiguous_format) 239 | self.v = self.v.clone(memory_format=torch.contiguous_format) 240 | 241 | def compute_one_iter(self): 242 | if not 
self.initialized: 243 | raise ValueError('Layer needs to be initialized first.') 244 | domain, codomain = self.compute_domain_codomain() 245 | if self.kernel_size == (1, 1): 246 | u = self.u.detach() 247 | v = self.v.detach() 248 | weight = self.weight.detach().view(self.out_channels, self.in_channels) 249 | u = normalize_u(torch.mv(weight, v), codomain) 250 | v = normalize_v(torch.mv(weight.t(), u), domain) 251 | return torch.dot(u, torch.mv(weight, v)) 252 | else: 253 | u = self.u.detach() 254 | v = self.v.detach() 255 | weight = self.weight.detach() 256 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item()) 257 | u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 258 | out_shape = u_s.shape 259 | u = normalize_u(u_s.view(-1), codomain) 260 | v_s = F.conv_transpose2d( 261 | u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0 262 | ) 263 | v = normalize_v(v_s.view(-1), domain) 264 | weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 265 | return torch.dot(u.view(-1), weight_v.view(-1)) 266 | 267 | def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None): 268 | if not self.initialized: 269 | self._initialize_u_v() 270 | 271 | if self.kernel_size == (1, 1): 272 | return self._compute_weight_1x1(update, n_iterations, atol, rtol) 273 | else: 274 | return self._compute_weight_kxk(update, n_iterations, atol, rtol) 275 | 276 | def _compute_weight_1x1(self, update=True, n_iterations=None, atol=None, rtol=None): 277 | n_iterations = self.n_iterations if n_iterations is None else n_iterations 278 | atol = self.atol if atol is None else atol 279 | rtol = self.rtol if rtol is None else atol 280 | 281 | if n_iterations is None and (atol is None or rtol is None): 282 | raise ValueError('Need one of n_iteration or (atol, rtol).') 283 | 284 | max_itrs = 200 285 | if n_iterations is not 
None: 286 | max_itrs = n_iterations 287 | 288 | u = self.u 289 | v = self.v 290 | weight = self.weight.view(self.out_channels, self.in_channels) 291 | if update: 292 | with torch.no_grad(): 293 | domain, codomain = self.compute_domain_codomain() 294 | itrs_used = 0 295 | for _ in range(max_itrs): 296 | old_v = v.clone() 297 | old_u = u.clone() 298 | 299 | u = normalize_u(torch.mv(weight, v), codomain, out=u) 300 | v = normalize_v(torch.mv(weight.t(), u), domain, out=v) 301 | 302 | itrs_used = itrs_used + 1 303 | 304 | if n_iterations is None and atol is not None and rtol is not None: 305 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5) 306 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5) 307 | tol_u = atol + rtol * torch.max(u) 308 | tol_v = atol + rtol * torch.max(v) 309 | if err_u < tol_u and err_v < tol_v: 310 | break 311 | if itrs_used > 0: 312 | if domain != 1 and domain != 2: 313 | self.v.copy_(v) 314 | if codomain != 2 and codomain != float('inf'): 315 | self.u.copy_(u) 316 | # These two lines are important, see https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm 317 | u = u.clone(memory_format=torch.contiguous_format) 318 | v = v.clone(memory_format=torch.contiguous_format) 319 | 320 | sigma = torch.dot(u, torch.mv(weight, v)) 321 | with torch.no_grad(): 322 | self.scale.copy_(sigma) 323 | # soft normalization: only when sigma larger than coeff 324 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff) 325 | weight = weight / factor 326 | return weight.view(self.out_channels, self.in_channels, 1, 1) 327 | 328 | def _compute_weight_kxk(self, update=True, n_iterations=None, atol=None, rtol=None): 329 | n_iterations = self.n_iterations if n_iterations is None else n_iterations 330 | atol = self.atol if atol is None else atol 331 | rtol = self.rtol if rtol is None else atol 332 | 333 | if n_iterations is None and (atol is None or rtol is None): 334 | raise ValueError('Need one of n_iteration or 
(atol, rtol).') 335 | 336 | max_itrs = 200 337 | if n_iterations is not None: 338 | max_itrs = n_iterations 339 | 340 | u = self.u 341 | v = self.v 342 | weight = self.weight 343 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item()) 344 | if update: 345 | with torch.no_grad(): 346 | domain, codomain = self.compute_domain_codomain() 347 | itrs_used = 0 348 | for _ in range(max_itrs): 349 | old_u = u.clone() 350 | old_v = v.clone() 351 | 352 | u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 353 | out_shape = u_s.shape 354 | u = normalize_u(u_s.view(-1), codomain, out=u) 355 | 356 | v_s = F.conv_transpose2d( 357 | u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0 358 | ) 359 | v = normalize_v(v_s.view(-1), domain, out=v) 360 | 361 | itrs_used = itrs_used + 1 362 | if n_iterations is None and atol is not None and rtol is not None: 363 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5) 364 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5) 365 | tol_u = atol + rtol * torch.max(u) 366 | tol_v = atol + rtol * torch.max(v) 367 | if err_u < tol_u and err_v < tol_v: 368 | break 369 | if itrs_used > 0: 370 | if domain != 2: 371 | self.v.copy_(v) 372 | if codomain != 2: 373 | self.u.copy_(u) 374 | # These two lines are important, see https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm 375 | v = v.clone(memory_format=torch.contiguous_format) 376 | u = u.clone(memory_format=torch.contiguous_format) 377 | 378 | weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 379 | weight_v = weight_v.view(-1) 380 | sigma = torch.dot(u.view(-1), weight_v) 381 | with torch.no_grad(): 382 | self.scale.copy_(sigma) 383 | # soft normalization: only when sigma larger than coeff 384 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff) 385 | weight = weight / 
factor 386 | return weight 387 | 388 | def forward(self, input): 389 | if not self.initialized: self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims)) 390 | weight = self.compute_weight(update=False) 391 | return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1) 392 | 393 | def extra_repr(self): 394 | domain, codomain = self.compute_domain_codomain() 395 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}') 396 | if self.padding != (0,) * len(self.padding): 397 | s += ', padding={padding}' 398 | if self.bias is None: 399 | s += ', bias=False' 400 | s += ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format( 401 | self.coeff, domain, codomain, self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain) 402 | ) 403 | return s.format(**self.__dict__) 404 | 405 | 406 | def projmax_(v): 407 | """Inplace argmax on absolute value.""" 408 | ind = torch.argmax(torch.abs(v)) 409 | v.zero_() 410 | v[ind] = 1 411 | return v 412 | 413 | 414 | def normalize_v(v, domain, out=None): 415 | if not torch.is_tensor(domain) and domain == 2: 416 | v = F.normalize(v, p=2, dim=0, out=out) 417 | elif domain == 1: 418 | v = projmax_(v) 419 | else: 420 | vabs = torch.abs(v) 421 | vph = v / vabs 422 | vph[torch.isnan(vph)] = 1 423 | vabs = vabs / torch.max(vabs) 424 | vabs = vabs**(1 / (domain - 1)) 425 | v = vph * vabs / vector_norm(vabs, domain) 426 | return v 427 | 428 | 429 | def normalize_u(u, codomain, out=None): 430 | if not torch.is_tensor(codomain) and codomain == 2: 431 | u = F.normalize(u, p=2, dim=0, out=out) 432 | elif codomain == float('inf'): 433 | u = projmax_(u) 434 | else: 435 | uabs = torch.abs(u) 436 | uph = u / uabs 437 | uph[torch.isnan(uph)] = 1 438 | uabs = uabs / torch.max(uabs) 439 | uabs = uabs**(codomain - 1) 440 | if codomain == 1: 441 | u = uph * uabs / vector_norm(uabs, float('inf')) 442 | else: 443 | u = uph * uabs / 
vector_norm(uabs, codomain / (codomain - 1)) 444 | return u 445 | 446 | 447 | def vector_norm(x, p): 448 | x = x.view(-1) 449 | return torch.sum(x**p)**(1 / p) 450 | 451 | 452 | def leaky_elu(x, a=0.3): 453 | return a * x + (1 - a) * F.elu(x) 454 | 455 | 456 | def asym_squash(x): 457 | return torch.tanh(-leaky_elu(-x + 0.5493061829986572)) * 2 + 3 458 | 459 | 460 | # def asym_squash(x): 461 | # return torch.tanh(x) / 2. + 2. 462 | 463 | 464 | def _ntuple(n): 465 | 466 | def parse(x): 467 | if isinstance(x, container_abcs.Iterable): 468 | return x 469 | return tuple(repeat(x, n)) 470 | 471 | return parse 472 | 473 | 474 | _single = _ntuple(1) 475 | _pair = _ntuple(2) 476 | _triple = _ntuple(3) 477 | _quadruple = _ntuple(4) 478 | 479 | if __name__ == '__main__': 480 | 481 | p = nn.Parameter(torch.tensor(2.1)) 482 | 483 | m = InducedNormConv2d(10, 2, 3, 1, 1, atol=1e-3, rtol=1e-3, domain=p, codomain=p) 484 | W = m.compute_weight() 485 | 486 | m.compute_one_iter().backward() 487 | print(p.grad) 488 | 489 | # m.weight.data.copy_(W) 490 | # W = m.compute_weight().cpu().detach().numpy() 491 | # import numpy as np 492 | # print( 493 | # '{} {} {}'.format( 494 | # np.linalg.norm(W, ord=2, axis=(0, 1)), 495 | # '>' if np.linalg.norm(W, ord=2, axis=(0, 1)) > m.scale else '<', 496 | # m.scale, 497 | # ) 498 | # ) 499 | -------------------------------------------------------------------------------- /lib/layers/base/utils.py: -------------------------------------------------------------------------------- 1 | from torch._six import container_abcs 2 | from itertools import repeat 3 | 4 | 5 | def _ntuple(n): 6 | 7 | def parse(x): 8 | if isinstance(x, container_abcs.Iterable): 9 | return x 10 | return tuple(repeat(x, n)) 11 | 12 | return parse 13 | 14 | 15 | _single = _ntuple(1) 16 | _pair = _ntuple(2) 17 | _triple = _ntuple(3) 18 | _quadruple = _ntuple(4) 19 | -------------------------------------------------------------------------------- /lib/layers/broyden.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as functional 4 | from torch.autograd import Function 5 | import numpy as np 6 | import pickle 7 | import sys 8 | import os 9 | from scipy.optimize import root 10 | import time 11 | from termcolor import colored 12 | 13 | import logging 14 | 15 | logger = logging.getLogger() 16 | 17 | 18 | def _safe_norm(v): 19 | if not torch.isfinite(v).all(): 20 | return np.inf 21 | return torch.norm(v) 22 | 23 | 24 | def scalar_search_armijo(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0): 25 | ite = 0 26 | phi_a0 = phi(alpha0) # First do an update with step size 1 27 | if phi_a0 <= phi0 + c1*alpha0*derphi0: 28 | return alpha0, phi_a0, ite 29 | 30 | # Otherwise, compute the minimizer of a quadratic interpolant 31 | alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0) 32 | phi_a1 = phi(alpha1) 33 | 34 | # Otherwise loop with cubic interpolation until we find an alpha which 35 | # satisfies the first Wolfe condition (since we are backtracking, we will 36 | # assume that the value of alpha is not too small and satisfies the second 37 | # condition. 
38 | while alpha1 > amin: # we are assuming alpha>0 is a descent direction 39 | factor = alpha0**2 * alpha1**2 * (alpha1-alpha0) 40 | a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \ 41 | alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0) 42 | a = a / factor 43 | b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \ 44 | alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0) 45 | b = b / factor 46 | 47 | alpha2 = (-b + torch.sqrt(torch.abs(b**2 - 3 * a * derphi0))) / (3.0*a) 48 | phi_a2 = phi(alpha2) 49 | ite += 1 50 | 51 | if (phi_a2 <= phi0 + c1*alpha2*derphi0): 52 | return alpha2, phi_a2, ite 53 | 54 | if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96: 55 | alpha2 = alpha1 / 2.0 56 | 57 | alpha0 = alpha1 58 | alpha1 = alpha2 59 | phi_a0 = phi_a1 60 | phi_a1 = phi_a2 61 | 62 | # Failed to find a suitable step length 63 | return None, phi_a1, ite 64 | 65 | 66 | def line_search(update, x0, g0, g, nstep=0, on=True): 67 | """ 68 | `update` is the propsoed direction of update. 69 | Code adapted from scipy. 70 | """ 71 | tmp_s = [0] 72 | tmp_g0 = [g0] 73 | tmp_phi = [torch.norm(g0)**2] 74 | s_norm = torch.norm(x0) / torch.norm(update) 75 | 76 | def phi(s, store=True): 77 | if s == tmp_s[0]: 78 | return tmp_phi[0] # If the step size is so small... 
just return something 79 | x_est = x0 + s * update 80 | g0_new = g(x_est) 81 | phi_new = _safe_norm(g0_new)**2 82 | if store: 83 | tmp_s[0] = s 84 | tmp_g0[0] = g0_new 85 | tmp_phi[0] = phi_new 86 | return phi_new 87 | 88 | if on: 89 | s, phi1, ite = scalar_search_armijo(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2) 90 | if (not on) or s is None: 91 | s = 1.0 92 | ite = 0 93 | 94 | x_est = x0 + s * update 95 | if s == tmp_s[0]: 96 | g0_new = tmp_g0[0] 97 | else: 98 | g0_new = g(x_est) 99 | return x_est, g0_new, x_est - x0, g0_new - g0, ite 100 | 101 | def rmatvec(part_Us, part_VTs, x): 102 | # Compute x^T(-I + UV^T) 103 | # x: (N, d) 104 | # part_Us: (N, d, threshold) 105 | # part_VTs: (N, threshold, d) 106 | if part_Us.nelement() == 0: 107 | return -x 108 | xTU = torch.einsum('bi, bij -> bj', x, part_Us) # (N, threshold) 109 | return -x + torch.einsum('bj, bji -> bi', xTU, part_VTs) # (N, d) 110 | 111 | 112 | def matvec(part_Us, part_VTs, x): 113 | # Compute (-I + UV^T)x 114 | # x: (N, d) 115 | # part_Us: (N, d, threshold) 116 | # part_VTs: (N, threshold, d) 117 | if part_Us.nelement() == 0: 118 | return -x 119 | VTx = torch.einsum('bji, bi -> bj', part_VTs, x) # (N, threshold) 120 | return -x + torch.einsum('bij, bj -> bi', part_Us, VTx) # (N, d) 121 | 122 | 123 | def broyden(g_, x0, threshold, eps, ls=False, name="unknown"): 124 | # LBFGS_thres = min(threshold, 20) 125 | LBFGS_thres = threshold 126 | 127 | x0_shape = x0.shape 128 | x0 = x0.view(x0_shape[0], -1) 129 | 130 | bsz, total_hsize = x0.size() 131 | eps = eps * np.sqrt(np.prod(x0.shape)) 132 | 133 | def g(x): 134 | return g_(x.view(x0_shape)).view(bsz, -1) 135 | 136 | x_est = x0 # (bsz, d) 137 | gx = g(x_est) # (bsz, d) 138 | nstep = 0 139 | tnstep = 0 140 | 141 | # For fast calculation of inv_jacobian (approximately) 142 | Us = torch.zeros(bsz, total_hsize, LBFGS_thres).to(x0) 143 | VTs = torch.zeros(bsz, LBFGS_thres, total_hsize).to(x0) 144 | update = -gx 145 | new_objective = init_objective = 
torch.norm(gx).item() 146 | prot_break = False 147 | trace = [init_objective] 148 | 149 | # To be used in protective breaks 150 | protect_thres = 1e6 151 | lowest = new_objective 152 | lowest_xest, lowest_gx, lowest_step = x_est, gx, nstep 153 | while new_objective >= eps and nstep < threshold: 154 | x_est, gx, delta_x, delta_gx, ite = line_search(update, x_est, gx, g, nstep=nstep, on=ls) 155 | nstep += 1 156 | tnstep += (ite+1) 157 | new_objective = torch.norm(gx).item() 158 | trace.append(new_objective) 159 | if new_objective < lowest: 160 | lowest_xest, lowest_gx = x_est.clone().detach(), gx.clone().detach() 161 | lowest = new_objective 162 | lowest_step = nstep 163 | if new_objective < eps: 164 | break 165 | if new_objective < 3*eps and nstep == threshold and np.max(trace[-threshold:]) / np.min(trace[-threshold:]) < 1.3: 166 | logger.info('Iterations exceeded 30 for broyden') 167 | # if there's hardly been any progress in the last 30 steps 168 | break 169 | if new_objective > init_objective * protect_thres: 170 | logger.info('Broyden failed') 171 | prot_break = True 172 | break 173 | 174 | part_Us, part_VTs = Us[:,:,:(nstep-1) % LBFGS_thres], VTs[:,:(nstep-1) % LBFGS_thres] 175 | vT = rmatvec(part_Us, part_VTs, delta_x) # (N, d) 176 | u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bi, bi -> b', vT, delta_gx)[:,None] 177 | vT[vT != vT] = 0 178 | u[u != u] = 0 179 | VTs[:,(nstep-1) % LBFGS_thres] = vT 180 | Us[:,:,(nstep-1) % LBFGS_thres] = u 181 | update = -matvec(Us[:,:,:nstep], VTs[:,:nstep], gx) 182 | 183 | Us, VTs = None, None 184 | return {"result": lowest_xest.view(x0_shape), 185 | "nstep": nstep, 186 | "tnstep": tnstep, 187 | "lowest_step": lowest_step, 188 | "diff": torch.norm(lowest_gx).item(), 189 | "diff_detail": torch.norm(lowest_gx, dim=1), 190 | "prot_break": prot_break, 191 | "trace": trace, 192 | "eps": eps, 193 | "threshold": threshold} 194 | 195 | 196 | def analyze_broyden(res_info, err=None, judge=True, name='forward', 
training=True, save_err=True): 197 | """ 198 | For debugging use only :-) 199 | """ 200 | res_est = res_info['result'] 201 | nstep = res_info['nstep'] 202 | diff = res_info['diff'] 203 | diff_detail = res_info['diff_detail'] 204 | prot_break = res_info['prot_break'] 205 | trace = res_info['trace'] 206 | eps = res_info['eps'] 207 | threshold = res_info['threshold'] 208 | if judge: 209 | return nstep >= threshold or (nstep == 0 and (diff != diff or diff > eps)) or prot_break or torch.isnan(res_est).any() 210 | 211 | assert (err is not None), "Must provide err information when not in judgment mode" 212 | prefix, color = ('', 'red') if name == 'forward' else ('back_', 'blue') 213 | eval_prefix = '' if training else 'eval_' 214 | 215 | # Case 1: A nan entry is produced in Broyden 216 | if torch.isnan(res_est).any(): 217 | msg = colored(f"WARNING: nan found in Broyden's {name} result. Diff: {diff}", color) 218 | print(msg) 219 | if save_err: pickle.dump(err, open(f'{prefix}{eval_prefix}nan.pkl', 'wb')) 220 | return (1, msg, res_info) 221 | 222 | # Case 2: Unknown problem with Broyden's method (probably due to nan update(s) to the weights) 223 | if nstep == 0 and (diff != diff or diff > eps): 224 | msg = colored(f"WARNING: Bad Broyden's method {name}. Why?? Diff: {diff}. STOP.", color) 225 | print(msg) 226 | if save_err: pickle.dump(err, open(f'{prefix}{eval_prefix}badbroyden.pkl', 'wb')) 227 | return (2, msg, res_info) 228 | 229 | # Case 3: Protective break during Broyden (so that it does not diverge to infinity) 230 | if prot_break: 231 | msg = colored(f"WARNING: Hit Protective Break in {name}. Diff: {diff}.
Total Iter: {len(trace)}", color) 232 | print(msg) 233 | if save_err: pickle.dump(err, open(f'{prefix}{eval_prefix}prot_break.pkl', 'wb')) 234 | return (3, msg, res_info) 235 | 236 | return (-1, '', res_info) -------------------------------------------------------------------------------- /lib/layers/container.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SequentialFlow(nn.Module): 5 | """A generalized nn.Sequential container for normalizing flows; inverse() runs the chain in reverse order. 6 | """ 7 | 8 | def __init__(self, layersList): 9 | super(SequentialFlow, self).__init__() 10 | self.chain = nn.ModuleList(layersList) 11 | 12 | def forward(self, x, logpx=None, restore=False): 13 | if logpx is None: 14 | for i in range(len(self.chain)): 15 | x = self.chain[i](x, restore=restore) # every layer in the chain must accept a restore= keyword 16 | return x 17 | else: 18 | for i in range(len(self.chain)): 19 | x, logpx = self.chain[i](x, logpx, restore=restore) # layers must accept restore= here too 20 | return x, logpx 21 | 22 | def inverse(self, y, logpy=None): 23 | if logpy is None: 24 | for i in range(len(self.chain) - 1, -1, -1): 25 | y = self.chain[i].inverse(y) 26 | return y 27 | else: 28 | for i in range(len(self.chain) - 1, -1, -1): 29 | y, logpy = self.chain[i].inverse(y, logpy) 30 | return y, logpy 31 | 32 | 33 | class Inverse(nn.Module): # swaps the forward() and inverse() of the wrapped flow 34 | 35 | def __init__(self, flow): 36 | super(Inverse, self).__init__() 37 | self.flow = flow 38 | 39 | def forward(self, x, logpx=None): # NOTE(review): takes no restore= keyword, so SequentialFlow.forward cannot call it as written; confirm intended use 40 | return self.flow.inverse(x, logpx) 41 | 42 | def inverse(self, y, logpy=None): 43 | return self.flow.forward(y, logpy) 44 | -------------------------------------------------------------------------------- /lib/layers/coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .
import mask_utils 4 | 5 | __all__ = ['CouplingBlock', 'ChannelCouplingBlock', 'MaskedCouplingBlock'] 6 | 7 | 8 | class CouplingBlock(nn.Module): 9 | """Basic coupling layer for Tensors of shape (n,d). 10 | 11 | Forward computation: 12 | y_a = x_a 13 | y_b = x_b * sigmoid(s(x_a) + 2) + t(x_a) 14 | Inverse computation: 15 | x_a = y_a 16 | x_b = (y_b - t(y_a)) / sigmoid(s(y_a) + 2) 17 | """ 18 | 19 | def __init__(self, dim, nnet, swap=False): 20 | """ 21 | Args: 22 | dim (int): even feature size; x is split into two halves of size dim // 2. nnet (nn.Module): maps the conditioning half to f, with s = f[:, :dim // 2] and t = f[:, dim // 2:]. 23 | swap (bool): if True, condition on the second half of x instead of the first. 24 | """ 25 | super(CouplingBlock, self).__init__() 26 | assert (dim % 2 == 0) 27 | self.d = dim // 2 28 | self.nnet = nnet 29 | self.swap = swap 30 | 31 | def func_s_t(self, x): 32 | f = self.nnet(x) 33 | s = f[:, :self.d] 34 | t = f[:, self.d:] 35 | return s, t 36 | 37 | def forward(self, x, logpx=None): 38 | """Forward computation of a simple coupling split on the axis=1. 39 | """ 40 | x_a = x[:, :self.d] if not self.swap else x[:, self.d:] 41 | x_b = x[:, self.d:] if not self.swap else x[:, :self.d] 42 | y_a, y_b, logdetgrad = self._forward_computation(x_a, x_b) 43 | y = [y_a, y_b] if not self.swap else [y_b, y_a] 44 | 45 | if logpx is None: 46 | return torch.cat(y, dim=1) 47 | else: 48 | return torch.cat(y, dim=1), logpx - logdetgrad.view(x.size(0), -1).sum(1, keepdim=True) 49 | 50 | def inverse(self, y, logpy=None): 51 | """Inverse computation of a simple coupling split on the axis=1. 52 | """ 53 | y_a = y[:, :self.d] if not self.swap else y[:, self.d:] 54 | y_b = y[:, self.d:] if not self.swap else y[:, :self.d] 55 | x_a, x_b, logdetgrad = self._inverse_computation(y_a, y_b) 56 | x = [x_a, x_b] if not self.swap else [x_b, x_a] 57 | if logpy is None: 58 | return torch.cat(x, dim=1) 59 | else: 60 | return torch.cat(x, dim=1), logpy + logdetgrad 61 | 62 | def _forward_computation(self, x_a, x_b): 63 | y_a = x_a 64 | s_a, t_a = self.func_s_t(x_a) 65 | scale = torch.sigmoid(s_a + 2.) # scale lies in (0, 1); with s_a near 0 it is about sigmoid(2) ~ 0.88
66 | y_b = x_b * scale + t_a 67 | logdetgrad = self._logdetgrad(scale) 68 | return y_a, y_b, logdetgrad 69 | 70 | def _inverse_computation(self, y_a, y_b): 71 | x_a = y_a 72 | s_a, t_a = self.func_s_t(y_a) 73 | scale = torch.sigmoid(s_a + 2.) 74 | x_b = (y_b - t_a) / scale 75 | logdetgrad = self._logdetgrad(scale) 76 | return x_a, x_b, logdetgrad 77 | 78 | def _logdetgrad(self, scale): 79 | """ 80 | Returns: 81 | Tensor (N, 1): containing ln |det J| where J is the jacobian 82 | """ 83 | return torch.log(scale).view(scale.shape[0], -1).sum(1, keepdim=True) 84 | 85 | def extra_repr(self): 86 | return 'dim={d}, swap={swap}'.format(**self.__dict__) 87 | 88 | 89 | class ChannelCouplingBlock(CouplingBlock): 90 | """Channel-wise coupling layer for images: 'channel0' conditions on the first half of axis 1, 'channel1' on the second. 91 | """ 92 | 93 | def __init__(self, dim, nnet, mask_type='channel0'): 94 | if mask_type == 'channel0': 95 | swap = False 96 | elif mask_type == 'channel1': 97 | swap = True 98 | else: 99 | raise ValueError('Unknown mask type.') 100 | super(ChannelCouplingBlock, self).__init__(dim, nnet, swap) 101 | self.mask_type = mask_type 102 | 103 | def extra_repr(self): 104 | return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__) 105 | 106 | 107 | class MaskedCouplingBlock(nn.Module): 108 | """Coupling layer for images implemented using masks. 109 | """ 110 | 111 | def __init__(self, dim, nnet, mask_type='checkerboard0'): 112 | nn.Module.__init__(self) 113 | self.d = dim 114 | self.nnet = nnet 115 | self.mask_type = mask_type 116 | 117 | def func_s_t(self, x): 118 | f = self.nnet(x) 119 | s = torch.sigmoid(f[:, :self.d] + 2.)
120 | t = f[:, self.d:] 121 | return s, t 122 | 123 | def forward(self, x, logpx=None): 124 | # get mask 125 | b = mask_utils.get_mask(x, mask_type=self.mask_type) 126 | 127 | # masked forward 128 | x_a = b * x 129 | s, t = self.func_s_t(x_a) 130 | y = (x * s + t) * (1 - b) + x_a # entries where b == 1 pass through unchanged 131 | 132 | if logpx is None: 133 | return y 134 | else: 135 | return y, logpx - self._logdetgrad(s, b) 136 | 137 | def inverse(self, y, logpy=None): 138 | # get mask 139 | b = mask_utils.get_mask(y, mask_type=self.mask_type) 140 | 141 | # masked inverse 142 | y_a = b * y 143 | s, t = self.func_s_t(y_a) 144 | x = y_a + (1 - b) * (y - t) / s 145 | 146 | if logpy is None: 147 | return x 148 | else: 149 | return x, logpy + self._logdetgrad(s, b) 150 | 151 | def _logdetgrad(self, s, mask): 152 | return torch.log(s).mul_(1 - mask).view(s.shape[0], -1).sum(1, keepdim=True) # only entries with mask == 0 are scaled, so only they contribute 153 | 154 | def extra_repr(self): 155 | return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__) 156 | -------------------------------------------------------------------------------- /lib/layers/elemwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | _DEFAULT_ALPHA = 1e-6 6 | 7 | 8 | class ZeroMeanTransform(nn.Module): 9 | 10 | def __init__(self): 11 | nn.Module.__init__(self) 12 | 13 | def forward(self, x, logpx=None): 14 | x = x - .5 # a constant shift has unit Jacobian, so logpx is returned unchanged 15 | if logpx is None: 16 | return x 17 | return x, logpx 18 | 19 | def inverse(self, y, logpy=None): 20 | y = y + .5 21 | if logpy is None: 22 | return y 23 | return y, logpy 24 | 25 | 26 | class Normalize(nn.Module): 27 | 28 | def __init__(self, mean, std): 29 | nn.Module.__init__(self) 30 | self.register_buffer('mean', torch.as_tensor(mean, dtype=torch.float32)) 31 | self.register_buffer('std', torch.as_tensor(std, dtype=torch.float32)) 32 | 33 | def forward(self, x, logpx=None): 34 | y = x.clone() 35 | c = len(self.mean) # only the first c channels are normalized; the rest are copied as-is 36 | y[:, :c].sub_(self.mean[None, :, None,
None]).div_(self.std[None, :, None, None]) 37 | if logpx is None: 38 | return y 39 | else: 40 | return y, logpx - self._logdetgrad(x) 41 | 42 | def inverse(self, y, logpy=None): 43 | x = y.clone() 44 | c = len(self.mean) 45 | x[:, :c].mul_(self.std[None, :, None, None]).add_(self.mean[None, :, None, None]) 46 | if logpy is None: 47 | return x 48 | else: 49 | return x, logpy + self._logdetgrad(x) 50 | 51 | def _logdetgrad(self, x): # -sum(log(abs(std))) over the normalized channels and all spatial positions; assumes 4-D x 52 | logdetgrad = ( 53 | self.std.abs().log().mul_(-1).view(1, -1, 1, 1).expand(x.shape[0], len(self.std), x.shape[2], x.shape[3]) 54 | ) 55 | return logdetgrad.reshape(x.shape[0], -1).sum(-1, keepdim=True) 56 | 57 | 58 | class LogitTransform(nn.Module): 59 | """ 60 | The preprocessing step used in Real NVP, with a = alpha: 61 | forward: y = logit(a + (1 - 2a) * x) 62 | inverse: x = (sigmoid(y) - a) / (1 - 2a) 63 | """ 64 | 65 | def __init__(self, alpha=_DEFAULT_ALPHA): 66 | nn.Module.__init__(self) 67 | self.alpha = alpha 68 | 69 | def forward(self, x, logpx=None, restore=False): # restore is unused; presumably accepted so SequentialFlow.forward can pass it 70 | s = self.alpha + (1 - 2 * self.alpha) * x 71 | y = torch.log(s) - torch.log(1 - s) 72 | if logpx is None: 73 | return y 74 | return y, logpx - self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True) 75 | 76 | def inverse(self, y, logpy=None): 77 | x = (torch.sigmoid(y) - self.alpha) / (1 - 2 * self.alpha) 78 | if logpy is None: 79 | return x 80 | return x, logpy + self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True) 81 | 82 | def _logdetgrad(self, x): 83 | s = self.alpha + (1 - 2 * self.alpha) * x 84 | logdetgrad = -torch.log(s - s * s) + math.log(1 - 2 * self.alpha) # elementwise log of d logit(s)/dx = (1 - 2a) / (s (1 - s)) 85 | return logdetgrad 86 | 87 | def __repr__(self): 88 | return ('{name}({alpha})'.format(name=self.__class__.__name__, **self.__dict__)) 89 | -------------------------------------------------------------------------------- /lib/layers/glow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 |
class InvertibleLinear(nn.Module): 7 | 8 | def __init__(self, dim): 9 | super(InvertibleLinear, self).__init__() 10 | self.dim = dim 11 | self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)]) 12 | 13 | def forward(self, x, logpx=None): 14 | y = F.linear(x, self.weight) 15 | if logpx is None: 16 | return y 17 | else: 18 | return y, logpx - self._logdetgrad 19 | 20 | def inverse(self, y, logpy=None): 21 | x = F.linear(y, self.weight.inverse()) 22 | if logpy is None: 23 | return x 24 | else: 25 | return x, logpy + self._logdetgrad 26 | 27 | @property 28 | def _logdetgrad(self): 29 | return torch.log(torch.abs(torch.det(self.weight))) 30 | 31 | def extra_repr(self): 32 | return 'dim={}'.format(self.dim) 33 | 34 | 35 | class InvertibleConv2d(nn.Module): 36 | 37 | def __init__(self, dim): 38 | super(InvertibleConv2d, self).__init__() 39 | self.dim = dim 40 | self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)]) 41 | 42 | def forward(self, x, logpx=None): 43 | y = F.conv2d(x, self.weight.view(self.dim, self.dim, 1, 1)) 44 | if logpx is None: 45 | return y 46 | else: 47 | return y, logpx - self._logdetgrad.expand_as(logpx) * x.shape[2] * x.shape[3] 48 | 49 | def inverse(self, y, logpy=None): 50 | x = F.conv2d(y, self.weight.inverse().view(self.dim, self.dim, 1, 1)) 51 | if logpy is None: 52 | return x 53 | else: 54 | return x, logpy + self._logdetgrad.expand_as(logpy) * x.shape[2] * x.shape[3] 55 | 56 | @property 57 | def _logdetgrad(self): 58 | return torch.log(torch.abs(torch.det(self.weight))) 59 | 60 | def extra_repr(self): 61 | return 'dim={}'.format(self.dim) 62 | -------------------------------------------------------------------------------- /lib/layers/implicit_block.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from .broyden import broyden 7 | import copy 8 | import lib.layers.base 
as base_layers 9 | 10 | import logging 11 | 12 | logger = logging.getLogger() 13 | 14 | __all__ = ['imBlock'] 15 | 16 | 17 | def find_fixed_point(g, y, threshold=1000, eps=1e-5): 18 | x, x_prev = g(y), y 19 | i = 0 20 | tol = eps + eps * y.abs() 21 | while not torch.all((x - x_prev)**2 / tol < 1.): 22 | x, x_prev = g(x), x 23 | i += 1 24 | if i > threshold: 25 | logger.info(torch.abs(x - x_prev).max()) 26 | logger.info('Iterations exceeded 1000 for fixed point.') 27 | break 28 | return x 29 | 30 | def primes(n): 31 | primfac = [] 32 | d = 2 33 | while d*d <= n: 34 | while (n % d) == 0: 35 | primfac.append(d) # supposing you want multiple factors repeated 36 | n //= d 37 | d += 1 38 | if n > 1: 39 | primfac.append(n) 40 | return list(set(primfac)) 41 | 42 | def choose_prime(n): 43 | if n == 32 * 32 * 3: 44 | return 3079 45 | elif n == 2: 46 | return 2 47 | else: 48 | assert False, 'Please specify the prime or the power of a prime given {}'.format(n) 49 | 50 | 51 | class RootFind(Function): 52 | @staticmethod 53 | def f(nnet_z, nnet_x, z, x): 54 | return nnet_x(x) - nnet_z(z) 55 | 56 | @staticmethod 57 | def banach_find_root(nnet_z, nnet_x, z0, x, *args): 58 | eps = args[-2] 59 | threshold = args[-1] # Can also set this to be different, based on training/inference 60 | x_embed = nnet_x(x) + x 61 | g = lambda z: x_embed - nnet_z(z) 62 | z_est = find_fixed_point(g, z0, threshold=threshold, eps=eps) 63 | if threshold > 100: 64 | torch.cuda.empty_cache() 65 | return z_est.clone().detach() 66 | 67 | @staticmethod 68 | def broyden_find_root(nnet_z, nnet_x, z0, x, *args): 69 | eps = args[-2] 70 | threshold = args[-1] # Can also set this to be different, based on training/inference 71 | x_embed = nnet_x(x) + x 72 | g = lambda z: x_embed - nnet_z(z) - z 73 | result_info = broyden(g, torch.zeros_like(z0).to(z0), threshold=threshold, eps=eps, name="forward") 74 | if result_info['prot_break']: 75 | z_est = RootFind.banach_find_root(nnet_z, nnet_x, z0, x, eps, 1000) 76 | else: 77 
| z_est = result_info['result'] 78 | if threshold > 100: 79 | torch.cuda.empty_cache() 80 | return z_est.clone().detach() 81 | 82 | @staticmethod 83 | def forward(ctx, nnet_z, nnet_x, z0, x, method, *args): 84 | if method == 'broyden': 85 | root_find = RootFind.broyden_find_root 86 | else: 87 | root_find = RootFind.banach_find_root 88 | ctx.args_len = len(args) 89 | with torch.no_grad(): 90 | z_est = root_find(nnet_z, nnet_x, z0, x, *args) 91 | 92 | # If one would like to analyze the convergence process (e.g., failures, stability), should 93 | # insert here or in broyden_find_root. 94 | return z_est 95 | 96 | @staticmethod 97 | def backward(ctx, grad_z): 98 | assert 0, 'Cannot backward to this function.' 99 | grad_args = [None for _ in range(ctx.args_len)] 100 | return (None, None, grad_z, None, *grad_args) 101 | 102 | 103 | class imBlock(nn.Module): 104 | 105 | def __init__( 106 | self, 107 | nnet_x, 108 | nnet_z, 109 | geom_p=0.5, 110 | lamb=2., 111 | n_power_series=None, 112 | exact_trace=False, 113 | brute_force=False, 114 | n_samples=1, 115 | n_exact_terms=2, 116 | n_exact_terms_test=20, 117 | n_dist='geometric', 118 | neumann_grad=True, 119 | grad_in_forward=True, 120 | eps_forward=1e-6, 121 | eps_backward=1e-10, 122 | eps_sample=1e-5, 123 | threshold=30, 124 | ): 125 | """ 126 | Args: 127 | nnet: a nn.Module 128 | n_power_series: number of power series. If not None, uses a biased approximation to logdet. 129 | exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian. 130 | brute_force: Computes the exact logdet. Only available for 2D inputs. 
131 | """ 132 | super(imBlock, self).__init__() 133 | 134 | self.nnet_x = nnet_x 135 | self.nnet_z = nnet_z 136 | self.nnet_x_copy = copy.deepcopy(self.nnet_x) 137 | self.nnet_z_copy = copy.deepcopy(self.nnet_z) 138 | for params in self.nnet_x_copy.parameters(): 139 | params.requires_grad_(False) 140 | for params in self.nnet_z_copy.parameters(): 141 | params.requires_grad_(False) 142 | 143 | self.n_dist = n_dist 144 | self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1. - geom_p))).float() 145 | self.lamb = nn.Parameter(torch.tensor(lamb)).float() 146 | self.n_samples = n_samples 147 | self.n_power_series = n_power_series 148 | self.exact_trace = exact_trace 149 | self.brute_force = brute_force 150 | self.n_exact_terms = n_exact_terms 151 | self.n_exact_terms_test = n_exact_terms_test 152 | self.grad_in_forward = grad_in_forward 153 | self.neumann_grad = neumann_grad 154 | self.eps_forward = eps_forward 155 | self.eps_backward = eps_backward 156 | self.eps_sample = eps_sample 157 | self.threshold = threshold 158 | 159 | # store the samples of n. 160 | self.register_buffer('last_n_samples', torch.zeros(self.n_samples)) 161 | self.register_buffer('last_firmom', torch.zeros(1)) 162 | self.register_buffer('last_secmom', torch.zeros(1)) 163 | 164 | 165 | class Backward(Function): 166 | """ 167 | A 'dummy' function that does nothing in the forward pass and perform implicit differentiation 168 | in the backward pass. Essentially a wrapper that provides backprop for the `imBlock` class. 169 | You should use this inner class in imBlock's forward() function by calling: 170 | 171 | self.Backward.apply(self.func, ...) 
172 | 173 | """ 174 | @staticmethod 175 | def forward(ctx, nnet_z, nnet_x, z, x, *args): 176 | ctx.save_for_backward(z, x) 177 | ctx.nnet_z = nnet_z 178 | ctx.nnet_x = nnet_x 179 | ctx.args = args 180 | return z 181 | 182 | @staticmethod 183 | def backward(ctx, grad): 184 | torch.cuda.empty_cache() 185 | 186 | grad = grad.clone() 187 | z, x = ctx.saved_tensors 188 | args = ctx.args 189 | eps, threshold = args[-2:] 190 | 191 | nnet_z = ctx.nnet_z 192 | nnet_x = ctx.nnet_x 193 | z = z.clone().detach().requires_grad_() 194 | x = x.clone().detach().requires_grad_() 195 | 196 | with torch.enable_grad(): 197 | Fz = nnet_z(z) + z 198 | 199 | def g(x_): 200 | Fz.backward(x_, retain_graph=True) # Retain for future calls to g 201 | xJ = z.grad.clone().detach() 202 | z.grad.zero_() 203 | return xJ - grad 204 | 205 | dl_dh = torch.zeros_like(grad).to(grad) 206 | result_info = broyden(g, dl_dh, threshold=threshold, eps=eps, name="backward") 207 | dl_dh = result_info['result'] 208 | Fz.backward(torch.zeros_like(dl_dh), retain_graph=False) 209 | 210 | with torch.enable_grad(): 211 | Fx = nnet_x(x) + x 212 | Fx.backward(dl_dh) 213 | dl_dx = x.grad.clone().detach() 214 | x.grad.zero_() 215 | 216 | grad_args = [None for _ in range(len(args))] 217 | return (None, None, dl_dh, dl_dx, *grad_args) 218 | 219 | 220 | def forward(self, x, logpx=None, restore=False): 221 | z0 = x.clone().detach() 222 | if restore: 223 | with torch.no_grad(): 224 | _ = self.nnet_x_copy(z0) 225 | _ = self.nnet_z_copy(z0) 226 | z = RootFind.apply(self.nnet_z, self.nnet_x, z0, z0, 'broyden', self.eps_forward, self.threshold) 227 | z = RootFind.f(self.nnet_z, self.nnet_x, z.detach(), z0) + z0 # For backwarding to parameters in func 228 | self.nnet_x_copy.load_state_dict(self.nnet_x.state_dict()) 229 | self.nnet_z_copy.load_state_dict(self.nnet_z.state_dict()) 230 | z = self.Backward.apply(self.nnet_z_copy, self.nnet_x_copy, z, x, 'broyden', self.eps_backward, self.threshold) 231 | if logpx is None: 232 | return 
z 233 | else: 234 | return z, logpx - self._logdetgrad(z, x) 235 | 236 | def inverse(self, z, logpy=None): 237 | x0 = z.clone().detach() 238 | x = RootFind.apply(self.nnet_x, self.nnet_z, x0, z, 'broyden', self.eps_sample, self.threshold) 239 | # x = RootFind.apply(self.nnet_x, self.nnet_z, x0, z, 'banach', self.eps_sample, self.threshold) 240 | if logpy is None: 241 | return x 242 | else: 243 | return x, logpy + self._logdetgrad(z, x) 244 | 245 | def _logdetgrad(self, z, x): 246 | """Returns logdet|dz/dx|.""" 247 | 248 | with torch.enable_grad(): 249 | if (self.brute_force or not self.training) and (x.ndimension() == 2 and x.shape[1] <= 10): 250 | x = x.requires_grad_(True) 251 | z = z.requires_grad_(True) 252 | Fx = x + self.nnet_x(x) 253 | Jx = batch_jacobian(Fx, x) 254 | logdet_x = torch.logdet(Jx) 255 | 256 | Fz = z + self.nnet_z(z) 257 | Jz = batch_jacobian(Fz, z) 258 | logdet_z = torch.logdet(Jz) 259 | 260 | return (logdet_x - logdet_z).view(-1, 1) 261 | if self.n_dist == 'geometric': 262 | geom_p = torch.sigmoid(self.geom_p).item() 263 | sample_fn = lambda m: geometric_sample(geom_p, m) 264 | rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset) 265 | elif self.n_dist == 'poisson': 266 | lamb = self.lamb.item() 267 | sample_fn = lambda m: poisson_sample(lamb, m) 268 | rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset) 269 | 270 | if self.training: 271 | if self.n_power_series is None: 272 | # Unbiased estimation. 273 | lamb = self.lamb.item() 274 | n_samples = sample_fn(self.n_samples) 275 | n_power_series = max(n_samples) + self.n_exact_terms 276 | coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \ 277 | sum(n_samples >= k - self.n_exact_terms) / len(n_samples) 278 | else: 279 | # Truncated estimation. 280 | n_power_series = self.n_power_series 281 | coeff_fn = lambda k: 1. 282 | else: 283 | # Unbiased estimation with more exact terms. 
284 | 285 | lamb = self.lamb.item() 286 | n_samples = sample_fn(self.n_samples) 287 | n_power_series = max(n_samples) + self.n_exact_terms_test 288 | coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms_test) * \ 289 | sum(n_samples >= k - self.n_exact_terms_test) / len(n_samples) 290 | 291 | if not self.exact_trace: 292 | #################################### 293 | # Power series with trace estimator. 294 | #################################### 295 | # vareps_x = torch.randn_like(x) 296 | # vareps_z = torch.randn_like(z) 297 | vareps_x = torch.distributions.bernoulli.Bernoulli(torch.Tensor([0.5])).sample(x.shape).reshape(x.shape).to(x) * 2 - 1 298 | vareps_z = torch.distributions.bernoulli.Bernoulli(torch.Tensor([0.5])).sample(z.shape).reshape(z.shape).to(z) * 2 - 1 299 | 300 | # Choose the type of estimator. 301 | if self.training and self.neumann_grad: 302 | estimator_fn = neumann_logdet_estimator 303 | else: 304 | estimator_fn = basic_logdet_estimator 305 | 306 | # Do backprop-in-forward to save memory. 307 | if self.training and self.grad_in_forward: 308 | logdet_x = mem_eff_wrapper( 309 | estimator_fn, self.nnet_x, x, n_power_series, vareps_x, coeff_fn, self.training 310 | ) 311 | logdet_z = mem_eff_wrapper( 312 | estimator_fn, self.nnet_z, z, n_power_series, vareps_z, coeff_fn, self.training 313 | ) 314 | logdetgrad = logdet_x - logdet_z 315 | else: 316 | x = x.requires_grad_(True) 317 | z = z.requires_grad_(True) 318 | Fx = self.nnet_x(x) 319 | Fz = self.nnet_z(z) 320 | logdet_x = estimator_fn(Fx, x, n_power_series, vareps_x, coeff_fn, self.training) 321 | logdet_z = estimator_fn(Fz, z, n_power_series, vareps_z, coeff_fn, self.training) 322 | logdetgrad = logdet_x - logdet_z 323 | else: 324 | ############################################ 325 | # Power series with exact trace computation. 
326 | ############################################ 327 | x = x.requires_grad_(True) 328 | z = z.requires_grad_(True) 329 | Fx = self.nnet_x(x) 330 | Jx = batch_jacobian(Fx, x) 331 | logdetJx = batch_trace(Jx) 332 | Jx_k = Jx 333 | for k in range(2, n_power_series + 1): 334 | Jx_k = torch.bmm(Jx, Jx_k) 335 | logdetJx = logdetJx + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(Jx_k) 336 | Fz = self.nnet_z(z) 337 | Jz = batch_jacobian(Fz, z) 338 | logdetJz = batch_trace(Jz) 339 | Jz_k = Jz 340 | for k in range(2, n_power_series + 1): 341 | Jz_k = torch.bmm(Jz, Jz_k) 342 | logdetJz = logdetJz + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(Jz_k) 343 | logdetgrad = logdetJx - logdetJz 344 | 345 | if self.training and self.n_power_series is None: 346 | self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples)) 347 | estimator = logdetgrad.detach() 348 | self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom)) 349 | self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom)) 350 | return logdetgrad.view(-1, 1) 351 | 352 | def extra_repr(self): 353 | return 'dist={}, n_samples={}, n_power_series={}, neumann_grad={}, exact_trace={}, brute_force={}, neumann_grad={}, grad_in_forward={}'.format( 354 | self.n_dist, self.n_samples, self.n_power_series, self.neumann_grad, self.exact_trace, self.brute_force, self.neumann_grad, self.grad_in_forward 355 | ) 356 | 357 | 358 | def batch_jacobian(g, x, create_graph=True): 359 | jac = [] 360 | for d in range(g.shape[1]): 361 | jac.append(torch.autograd.grad(torch.sum(g[:, d]), x, create_graph=create_graph)[0].view(x.shape[0], 1, x.shape[1])) 362 | return torch.cat(jac, 1) 363 | 364 | 365 | def batch_trace(M): 366 | return M.view(M.shape[0], -1)[:, ::M.shape[1] + 1].sum(1) 367 | 368 | 369 | 370 | ##################### 371 | # Logdet Estimators 372 | ##################### 373 | class MemoryEfficientLogDetEstimator(torch.autograd.Function): 374 | 375 | @staticmethod 376 | def forward(ctx, estimator_fn, gnet, 
x, n_power_series, vareps, coeff_fn, training, *g_params): 377 | ctx.training = training 378 | with torch.enable_grad(): 379 | x = x.detach().requires_grad_(True) 380 | g = gnet(x) 381 | ctx.g = g 382 | ctx.x = x 383 | logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training) 384 | 385 | if training: 386 | grad_x, *grad_params = torch.autograd.grad( 387 | logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True 388 | ) 389 | if grad_x is None: 390 | grad_x = torch.zeros_like(x) 391 | ctx.save_for_backward(grad_x, *g_params, *grad_params) 392 | 393 | return safe_detach(logdetgrad) 394 | 395 | @staticmethod 396 | def backward(ctx, grad_logdetgrad): 397 | training = ctx.training 398 | if not training: 399 | raise ValueError('Provide training=True if using backward.') 400 | 401 | with torch.enable_grad(): 402 | grad_x, *params_and_grad = ctx.saved_tensors 403 | g, x = ctx.g, ctx.x 404 | 405 | # Precomputed gradients. 406 | g_params = params_and_grad[:len(params_and_grad) // 2] 407 | grad_params = params_and_grad[len(params_and_grad) // 2:] 408 | 409 | # Update based on gradient from logdetgrad. 
410 | dL = grad_logdetgrad[0].detach() 411 | with torch.no_grad(): 412 | grad_x.mul_(dL) 413 | grad_params = tuple([g.mul_(dL) if g is not None else None for g in grad_params]) 414 | 415 | return (None, None, grad_x, None, None, None, None) + grad_params 416 | 417 | 418 | def basic_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training): 419 | vjp = vareps 420 | logdetgrad = torch.tensor(0.).to(x) 421 | for k in range(1, n_power_series + 1): 422 | vjp = torch.autograd.grad(g, x, vjp, create_graph=training, retain_graph=True)[0] 423 | tr = torch.sum(vjp.view(x.shape[0], -1) * vareps.view(x.shape[0], -1), 1) 424 | delta = (-1)**(k + 1) / k * coeff_fn(k) * tr 425 | logdetgrad = logdetgrad + delta 426 | return logdetgrad 427 | 428 | 429 | def neumann_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training): 430 | vjp = vareps 431 | neumann_vjp = vareps 432 | with torch.no_grad(): 433 | for k in range(1, n_power_series + 1): 434 | vjp = torch.autograd.grad(g, x, vjp, retain_graph=True)[0] 435 | neumann_vjp = neumann_vjp + (-1)**k * coeff_fn(k) * vjp 436 | vjp_jac = torch.autograd.grad(g, x, neumann_vjp, create_graph=training)[0] 437 | logdetgrad = torch.sum(vjp_jac.view(x.shape[0], -1) * vareps.view(x.shape[0], -1), 1) 438 | return logdetgrad 439 | 440 | 441 | def mem_eff_wrapper(estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training): 442 | 443 | # We need this in order to access the variables inside this module, 444 | # since we have no other way of getting variables along the execution path. 445 | if not isinstance(gnet, nn.Module): 446 | raise ValueError('g is required to be an instance of nn.Module.') 447 | 448 | return MemoryEfficientLogDetEstimator.apply( 449 | estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *list(gnet.parameters()) 450 | ) 451 | 452 | 453 | # -------- Helper distribution functions -------- 454 | # These take python ints or floats, not PyTorch tensors. 
455 | 456 | 457 | def geometric_sample(p, n_samples): 458 | return np.random.geometric(p, n_samples) 459 | 460 | 461 | def geometric_1mcdf(p, k, offset): 462 | if k <= offset: 463 | return 1. 464 | else: 465 | k = k - offset 466 | """P(n >= k)""" 467 | return (1 - p)**max(k - 1, 0) 468 | 469 | 470 | def poisson_sample(lamb, n_samples): 471 | return np.random.poisson(lamb, n_samples) 472 | 473 | 474 | def poisson_1mcdf(lamb, k, offset): 475 | if k <= offset: 476 | return 1. 477 | else: 478 | k = k - offset 479 | """P(n >= k)""" 480 | sum = 1. 481 | for i in range(1, k): 482 | sum += lamb**i / math.factorial(i) 483 | return 1 - np.exp(-lamb) * sum 484 | 485 | 486 | def sample_rademacher_like(y): 487 | return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1 488 | 489 | 490 | # -------------- Helper functions -------------- 491 | 492 | 493 | def safe_detach(tensor): 494 | return tensor.detach().requires_grad_(tensor.requires_grad) 495 | 496 | 497 | def _flatten(sequence): 498 | flat = [p.reshape(-1) for p in sequence] 499 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) 500 | 501 | 502 | def _flatten_convert_none_to_zeros(sequence, like_sequence): 503 | flat = [p.reshape(-1) if p is not None else torch.zeros_like(q).view(-1) for p, q in zip(sequence, like_sequence)] 504 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) 505 | -------------------------------------------------------------------------------- /lib/layers/iresblock.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | import logging 7 | 8 | logger = logging.getLogger() 9 | 10 | __all__ = ['iResBlock'] 11 | 12 | 13 | class iResBlock(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | nnet, 18 | geom_p=0.5, 19 | lamb=2., 20 | n_power_series=None, 21 | exact_trace=False, 22 | brute_force=False, 23 | n_samples=1, 24 | n_exact_terms=2, 25 | 
n_dist='geometric', 26 | neumann_grad=True, 27 | grad_in_forward=False, 28 | ): 29 | """ 30 | Args: 31 | nnet: a nn.Module 32 | n_power_series: number of power series. If not None, uses a biased approximation to logdet. 33 | exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian. 34 | brute_force: Computes the exact logdet. Only available for 2D inputs. 35 | """ 36 | nn.Module.__init__(self) 37 | self.nnet = nnet 38 | self.n_dist = n_dist 39 | self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1. - geom_p))) 40 | self.lamb = nn.Parameter(torch.tensor(lamb)) 41 | self.n_samples = n_samples 42 | self.n_power_series = n_power_series 43 | self.exact_trace = exact_trace 44 | self.brute_force = brute_force 45 | self.n_exact_terms = n_exact_terms 46 | self.grad_in_forward = grad_in_forward 47 | self.neumann_grad = neumann_grad 48 | 49 | # store the samples of n. 50 | self.register_buffer('last_n_samples', torch.zeros(self.n_samples)) 51 | self.register_buffer('last_firmom', torch.zeros(1)) 52 | self.register_buffer('last_secmom', torch.zeros(1)) 53 | 54 | def forward(self, x, logpx=None): 55 | if logpx is None: 56 | y = x + self.nnet(x) 57 | return y 58 | else: 59 | g, logdetgrad = self._logdetgrad(x) 60 | return x + g, logpx - logdetgrad 61 | 62 | def inverse(self, y, logpy=None): 63 | x = self._inverse_fixed_point(y) 64 | if logpy is None: 65 | return x 66 | else: 67 | return x, logpy + self._logdetgrad(x)[1] 68 | 69 | def _inverse_fixed_point(self, y, atol=1e-5, rtol=1e-5): 70 | x, x_prev = y - self.nnet(y), y 71 | i = 0 72 | tol = atol + y.abs() * rtol 73 | while not torch.all((x - x_prev)**2 / tol < 1): 74 | x, x_prev = y - self.nnet(x), x 75 | i += 1 76 | if i > 1000: 77 | logger.info('Iterations exceeded 1000 for inverse.') 78 | break 79 | return x 80 | 81 | def _logdetgrad(self, x): 82 | """Returns g(x) and logdet|d(x+g(x))/dx|.""" 83 | 84 | with torch.enable_grad(): 85 | if (self.brute_force or not 
self.training) and (x.ndimension() == 2 and x.shape[1] == 2): 86 | ########################################### 87 | # Brute-force compute Jacobian determinant. 88 | ########################################### 89 | x = x.requires_grad_(True) 90 | g = self.nnet(x) 91 | # Brute-force logdet only available for 2D. 92 | jac = batch_jacobian(g, x) 93 | batch_dets = (jac[:, 0, 0] + 1) * (jac[:, 1, 1] + 1) - jac[:, 0, 1] * jac[:, 1, 0] 94 | return g, torch.log(torch.abs(batch_dets)).view(-1, 1) 95 | 96 | if self.n_dist == 'geometric': 97 | geom_p = torch.sigmoid(self.geom_p).item() 98 | sample_fn = lambda m: geometric_sample(geom_p, m) 99 | rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset) 100 | elif self.n_dist == 'poisson': 101 | lamb = self.lamb.item() 102 | sample_fn = lambda m: poisson_sample(lamb, m) 103 | rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset) 104 | 105 | if self.training: 106 | if self.n_power_series is None: 107 | # Unbiased estimation. 108 | lamb = self.lamb.item() 109 | n_samples = sample_fn(self.n_samples) 110 | n_power_series = max(n_samples) + self.n_exact_terms 111 | coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \ 112 | sum(n_samples >= k - self.n_exact_terms) / len(n_samples) 113 | else: 114 | # Truncated estimation. 115 | n_power_series = self.n_power_series 116 | coeff_fn = lambda k: 1. 117 | else: 118 | # Unbiased estimation with more exact terms. 119 | lamb = self.lamb.item() 120 | n_samples = sample_fn(self.n_samples) 121 | n_power_series = max(n_samples) + 20 122 | coeff_fn = lambda k: 1 / rcdf_fn(k, 20) * \ 123 | sum(n_samples >= k - 20) / len(n_samples) 124 | 125 | if not self.exact_trace: 126 | #################################### 127 | # Power series with trace estimator. 128 | #################################### 129 | vareps = torch.randn_like(x) 130 | 131 | # Choose the type of estimator. 
132 | if self.training and self.neumann_grad: 133 | estimator_fn = neumann_logdet_estimator 134 | else: 135 | estimator_fn = basic_logdet_estimator 136 | 137 | # Do backprop-in-forward to save memory. 138 | if self.training and self.grad_in_forward: 139 | g, logdetgrad = mem_eff_wrapper( 140 | estimator_fn, self.nnet, x, n_power_series, vareps, coeff_fn, self.training 141 | ) 142 | else: 143 | x = x.requires_grad_(True) 144 | g = self.nnet(x) 145 | logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, self.training) 146 | else: 147 | ############################################ 148 | # Power series with exact trace computation. 149 | ############################################ 150 | x = x.requires_grad_(True) 151 | g = self.nnet(x) 152 | jac = batch_jacobian(g, x) 153 | logdetgrad = batch_trace(jac) 154 | jac_k = jac 155 | for k in range(2, n_power_series + 1): 156 | jac_k = torch.bmm(jac, jac_k) 157 | logdetgrad = logdetgrad + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(jac_k) 158 | 159 | if self.training and self.n_power_series is None: 160 | self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples)) 161 | estimator = logdetgrad.detach() 162 | self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom)) 163 | self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom)) 164 | return g, logdetgrad.view(-1, 1) 165 | 166 | def extra_repr(self): 167 | return 'dist={}, n_samples={}, n_power_series={}, neumann_grad={}, exact_trace={}, brute_force={}'.format( 168 | self.n_dist, self.n_samples, self.n_power_series, self.neumann_grad, self.exact_trace, self.brute_force 169 | ) 170 | 171 | 172 | def batch_jacobian(g, x): 173 | jac = [] 174 | for d in range(g.shape[1]): 175 | jac.append(torch.autograd.grad(torch.sum(g[:, d]), x, create_graph=True)[0].view(x.shape[0], 1, x.shape[1])) 176 | return torch.cat(jac, 1) 177 | 178 | 179 | def batch_trace(M): 180 | return M.view(M.shape[0], -1)[:, ::M.shape[1] + 1].sum(1) 181 | 182 | 183 | 
#####################
# Logdet Estimators
#####################
class MemoryEfficientLogDetEstimator(torch.autograd.Function):
    """Evaluate ``g = gnet(x)`` and a log-det estimate, precomputing log-det gradients.

    When ``training`` is true, ``forward`` already computes the gradients of
    ``logdetgrad.sum()`` with respect to ``x`` and ``g_params`` and stores them with
    ``ctx.save_for_backward``. ``backward`` then only backpropagates the upstream
    gradient through ``g`` and rescales the stored log-det gradients.
    """

    @staticmethod
    def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
        ctx.training = training
        with torch.enable_grad():
            # Work on a fresh leaf so the estimator's autograd calls stop at this input.
            x = x.detach().requires_grad_(True)
            g = gnet(x)
            ctx.g, ctx.x = g, x
            logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)

            if training:
                leaves = (x,) + g_params
                first, *rest = torch.autograd.grad(
                    logdetgrad.sum(), leaves, retain_graph=True, allow_unused=True
                )
                # allow_unused=True may yield None for x; backward needs a tensor to accumulate into.
                logdet_grad_x = torch.zeros_like(x) if first is None else first
                # Saved layout: grad w.r.t. x, then the parameters, then their log-det gradients.
                ctx.save_for_backward(logdet_grad_x, *g_params, *rest)

        return safe_detach(g), safe_detach(logdetgrad)

    @staticmethod
    def backward(ctx, grad_g, grad_logdetgrad):
        if not ctx.training:
            raise ValueError('Provide training=True if using backward.')

        with torch.enable_grad():
            saved = ctx.saved_tensors
            logdet_grad_x = saved[0]
            stored = list(saved[1:])
            n_params = len(stored) // 2
            g_params, logdet_grad_params = stored[:n_params], stored[n_params:]

            # Backpropagate the upstream gradient of g through gnet.
            out_grad_x, *out_grad_params = torch.autograd.grad(
                ctx.g, [ctx.x] + g_params, grad_g, allow_unused=True
            )

        # NOTE(review): only the first entry of grad_logdetgrad rescales every stored
        # gradient. That is exact only if all entries are equal (e.g. the loss weights
        # each sample's logdetgrad identically) -- confirm at call sites.
        scale = grad_logdetgrad[0].detach()
        with torch.no_grad():
            # Stored log-det gradient, rescaled, plus the gradient flowing through g.
            logdet_grad_x.mul_(scale).add_(out_grad_x)
            param_grads = []
            for dg, djac in zip(out_grad_params, logdet_grad_params):
                if djac is None:
                    param_grads.append(dg)
                else:
                    param_grads.append(dg.add_(djac.mul_(scale)))

        # Gradients for (estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params).
        return (None, None, logdet_grad_x, None, None, None, None) + tuple(param_grads)


def basic_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Truncated power-series log-det estimate using the probe vector ``vareps``.

    Term k is ``(-1)**(k+1) / k * coeff_fn(k) * <vareps, vareps^T J^k>``, where each
    vector-Jacobian product comes from ``torch.autograd.grad(g, x, ...)``. The graph is
    kept (``create_graph``) only when ``training`` is true. Returns one value per row of
    ``x``, or a scalar zero tensor when ``n_power_series`` is 0.
    """
    batch = x.shape[0]
    probe = vareps.view(batch, -1)
    logdetgrad = torch.tensor(0.).to(x)
    vjp = vareps
    for k in range(1, n_power_series + 1):
        vjp = torch.autograd.grad(g, x, vjp, create_graph=training, retain_graph=True)[0]
        trace_k = torch.sum(vjp.view(batch, -1) * probe, 1)
        logdetgrad = logdetgrad + (-1)**(k + 1) / k * coeff_fn(k) * trace_k
    return logdetgrad


def neumann_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Log-det estimate whose gradient goes through a single differentiable VJP.

    The weighted Neumann series ``vareps + sum_k (-1)**k coeff_fn(k) vareps^T J^k`` is
    accumulated without building a graph. Only the final product with the Jacobian is
    taken with ``create_graph=training``.
    """
    vjp = neumann_vjp = vareps
    with torch.no_grad():
        for k in range(1, n_power_series + 1):
            vjp = torch.autograd.grad(g, x, vjp, retain_graph=True)[0]
            neumann_vjp = neumann_vjp + (-1)**k * coeff_fn(k) * vjp
    vjp_jac = torch.autograd.grad(g, x, neumann_vjp, create_graph=training)[0]
    batch = x.shape[0]
    return torch.sum(vjp_jac.view(batch, -1) * vareps.view(batch, -1), 1)


def mem_eff_wrapper(estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training):
    """Run MemoryEfficientLogDetEstimator, passing gnet's parameters as explicit inputs.

    The autograd Function can only return gradients for its inputs, so the parameters of
    ``gnet`` (which must therefore be an ``nn.Module``) are forwarded after the other args.
    """
    if isinstance(gnet, nn.Module):
        params = list(gnet.parameters())
        return MemoryEfficientLogDetEstimator.apply(
            estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *params
        )
    raise ValueError('g is required to be an instance of nn.Module.')


# -------- Helper distribution functions --------
# These take python ints or floats, not PyTorch tensors.
def geometric_sample(p, n_samples):
    """Draw ``n_samples`` geometric samples with success probability ``p`` (numpy RNG)."""
    return np.random.geometric(p, n_samples)


def geometric_1mcdf(p, k, offset):
    """Tail probability P(n >= k) for the shifted geometric series length.

    Returns 1 for ``k <= offset``; otherwise ``(1 - p) ** max(k - offset - 1, 0)``.
    """
    if k <= offset:
        return 1.
    k = k - offset
    return (1 - p)**max(k - 1, 0)


def poisson_sample(lamb, n_samples):
    """Draw ``n_samples`` Poisson samples with rate ``lamb`` (numpy RNG)."""
    return np.random.poisson(lamb, n_samples)


def poisson_1mcdf(lamb, k, offset):
    """Tail probability P(n >= k) for the shifted Poisson series length.

    Returns 1 for ``k <= offset``; otherwise
    ``1 - exp(-lamb) * sum_{i=0}^{k - offset - 1} lamb**i / i!``.
    """
    if k <= offset:
        return 1.
    k = k - offset
    # Partial sum of the Poisson pmf numerators (named to avoid shadowing builtin ``sum``).
    partial = 1.
    for i in range(1, k):
        partial += lamb**i / math.factorial(i)
    return 1 - np.exp(-lamb) * partial


def sample_rademacher_like(y):
    """Tensor shaped like ``y`` with entries uniformly in {-1, +1}, cast to y's dtype/device."""
    return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1


# -------------- Helper functions --------------


def safe_detach(tensor):
    """Detached alias of ``tensor`` that keeps its ``requires_grad`` flag."""
    return tensor.detach().requires_grad_(tensor.requires_grad)


def _flatten(sequence):
    """Concatenate the flattened tensors of ``sequence``; empty input gives ``torch.tensor([])``."""
    flat = [p.reshape(-1) for p in sequence]
    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])


def _flatten_convert_none_to_zeros(sequence, like_sequence):
    """Like ``_flatten``, but each None entry becomes zeros shaped like the matching ``like_sequence`` entry."""
    flat = [p.reshape(-1) if p is not None else torch.zeros_like(q).view(-1) for p, q in zip(sequence, like_sequence)]
    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])


# --------------------------------------------------------------------------------
# /lib/layers/mask_utils.py:
# --------------------------------------------------------------------------------
import torch


def _get_checkerboard_mask(x, swap=False):
    """Checkerboard mask over (h, w), broadcast to x's (n, c, h, w) with x's dtype/device.

    ``swap=False`` puts 1 at (0, 0); ``swap=True`` gives the complementary pattern.
    """
    n, c, h, w = x.size()

    H = ((h - 1) // 2 + 1) * 2  # H = h + 1 if h is odd and h if h is even
    W = ((w - 1) // 2 + 1) * 2

    # construct checkerboard mask
    if not swap:
        mask = torch.Tensor([[1, 0], [0, 1]]).repeat(H // 2, W // 2)
    else:
        mask = torch.Tensor([[0, 1], [1, 0]]).repeat(H // 2, W // 2)
    mask = mask[:h, :w]
    # expand() returns a broadcast view (no copy); treat the result as read-only.
    mask = mask.contiguous().view(1, 1, h, w).expand(n, c, h, w).type_as(x.data)

    return mask


def _get_channel_mask(x, swap=False):
    """Mask selecting the first (``swap=False``) or second (``swap=True``) half of the channels.

    Requires a 4-D ``x`` with an even channel count. The mask has x's shape, dtype and device.
    """
    n, c, h, w = x.size()
    assert (c % 2 == 0)

    # construct channel-wise mask on x's device/dtype, like the other mask types, so that
    # ``x * mask`` works for CUDA / non-default-dtype inputs.
    mask = x.new_zeros(x.size())
    if not swap:
        mask[:, :c // 2] = 1
    else:
        mask[:, c // 2:] = 1
    return mask


def get_mask(x, mask_type=None):
    """Return a mask shaped like ``x`` for the given ``mask_type``.

    ``None`` gives all zeros; 'channel0'/'channel1' split channels in half;
    'checkerboard0'/'checkerboard1' alternate spatially. Raises ValueError otherwise.
    """
    if mask_type is None:
        return torch.zeros(x.size()).to(x)
    elif mask_type == 'channel0':
        return _get_channel_mask(x, swap=False)
    elif mask_type == 'channel1':
        return _get_channel_mask(x, swap=True)
    elif mask_type == 'checkerboard0':
        return _get_checkerboard_mask(x, swap=False)
    elif mask_type == 'checkerboard1':
        return _get_checkerboard_mask(x, swap=True)
    else:
        raise ValueError('Unknown mask type {}'.format(mask_type))


# --------------------------------------------------------------------------------
# /lib/layers/normalization.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import Parameter

__all__ = ['MovingBatchNorm1d', 'MovingBatchNorm2d']


class MovingBatchNormNd(nn.Module):
    """Mean-only moving batch norm: ``y = x - mean + bias``.

    The subtracted mean is the running mean. When ``bn_lag > 0`` during training, it is
    instead a lagged, bias-corrected blend of the batch and running means. The running
    mean is updated with rate ``decay`` on every training call. The map is a translation,
    so ``logpx``/``logpy`` are returned unchanged. ``eps`` is stored (and shown in the
    repr) but not used by this mean-only layer. Subclasses define ``shape`` for broadcasting.
    """

    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        raise NotImplementedError

    def reset_parameters(self):
        self.running_mean.zero_()
        if self.affine:
            self.bias.data.zero_()

    def forward(self, x, logpx=None):
        c = x.size(1)
        used_mean = self.running_mean.clone().detach()

        if self.training:
            # compute batch statistics
            x_t = x.transpose(0, 1).contiguous().view(c, -1)
            batch_mean = torch.mean(x_t, dim=1)

            # moving average
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.step += 1

        # perform normalization
        used_mean = used_mean.view(*self.shape).expand_as(x)

        y = x - used_mean

        if self.affine:
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y + bias

        if logpx is None:
            return y
        else:
            return y, logpx

    def inverse(self, y, logpy=None):
        # Always uses the running mean (no batch statistics), undoing forward in eval mode.
        used_mean = self.running_mean

        if self.affine:
            bias = self.bias.view(*self.shape).expand_as(y)
            y = y - bias

        used_mean = used_mean.view(*self.shape).expand_as(y)
        x = y + used_mean

        if logpy is None:
            return x
        else:
            return x, logpy

    def __repr__(self):
        return (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
        )


class MovingBatchNorm1d(MovingBatchNormNd):
    """MovingBatchNormNd for (N, C) inputs."""

    @property
    def shape(self):
        return [1, -1]


class MovingBatchNorm2d(MovingBatchNormNd):
    """MovingBatchNormNd for (N, C, H, W) inputs."""

    @property
    def shape(self):
        return [1, -1, 1, 1]


# --------------------------------------------------------------------------------
# /lib/layers/squeeze.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn

__all__ = ['SqueezeLayer']

class SqueezeLayer(nn.Module): 8 | 9 | def __init__(self, downscale_factor): 10 | super(SqueezeLayer, self).__init__() 11 | self.downscale_factor = downscale_factor 12 | 13 | def forward(self, x, logpx=None, restore=False): 14 | squeeze_x = squeeze(x, self.downscale_factor) 15 | if logpx is None: 16 | return squeeze_x 17 | else: 18 | return squeeze_x, logpx 19 | 20 | def inverse(self, y, logpy=None): 21 | unsqueeze_y = unsqueeze(y, self.downscale_factor) 22 | if logpy is None: 23 | return unsqueeze_y 24 | else: 25 | return unsqueeze_y, logpy 26 | 27 | 28 | def unsqueeze(input, upscale_factor=2): 29 | return torch.pixel_shuffle(input, upscale_factor) 30 | 31 | 32 | def squeeze(input, downscale_factor=2): 33 | ''' 34 | [:, C, H*r, W*r] -> [:, C*r^2, H, W] 35 | ''' 36 | batch_size, in_channels, in_height, in_width = input.shape 37 | out_channels = in_channels * (downscale_factor**2) 38 | 39 | out_height = in_height // downscale_factor 40 | out_width = in_width // downscale_factor 41 | 42 | input_view = input.reshape(batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor) 43 | 44 | output = input_view.permute(0, 1, 3, 5, 2, 4) 45 | return output.reshape(batch_size, out_channels, out_height, out_width) 46 | -------------------------------------------------------------------------------- /lib/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class CosineAnnealingWarmRestarts(_LRScheduler): 6 | r"""Set the learning rate of each parameter group using a cosine annealing 7 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` 8 | is the number of epochs since the last restart and :math:`T_{i}` is the number 9 | of epochs between two warm restarts in SGDR: 10 | .. 
math:: 11 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 12 | \cos(\frac{T_{cur}}{T_{i}}\pi)) 13 | When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 14 | When :math:`T_{cur}=0`(after restart), set :math:`\eta_t=\eta_{max}`. 15 | It has been proposed in 16 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. 17 | Args: 18 | optimizer (Optimizer): Wrapped optimizer. 19 | T_0 (int): Number of iterations for the first restart. 20 | T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. 21 | eta_min (float, optional): Minimum learning rate. Default: 0. 22 | last_epoch (int, optional): The index of last epoch. Default: -1. 23 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 24 | https://arxiv.org/abs/1608.03983 25 | """ 26 | 27 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): 28 | if T_0 <= 0 or not isinstance(T_0, int): 29 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 30 | if T_mult < 1 or not isinstance(T_mult, int): 31 | raise ValueError("Expected integer T_mul >= 1, but got {}".format(T_mult)) 32 | self.T_0 = T_0 33 | self.T_i = T_0 34 | self.T_mult = T_mult 35 | self.eta_min = eta_min 36 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) 37 | self.T_cur = last_epoch 38 | 39 | def get_lr(self): 40 | return [ 41 | self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 42 | for base_lr in self.base_lrs 43 | ] 44 | 45 | def step(self, epoch=None): 46 | """Step could be called after every update, i.e. if one epoch has 10 iterations 47 | (number_of_train_examples / batch_size), we should call SGDR.step(0.1), SGDR.step(0.2), etc. 48 | This function can be called in an interleaved way. 
49 | Example: 50 | >>> scheduler = SGDR(optimizer, T_0, T_mult) 51 | >>> for epoch in range(20): 52 | >>> scheduler.step() 53 | >>> scheduler.step(26) 54 | >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) 55 | """ 56 | if epoch is None: 57 | epoch = self.last_epoch + 1 58 | self.T_cur = self.T_cur + 1 59 | if self.T_cur >= self.T_i: 60 | self.T_cur = self.T_cur - self.T_i 61 | self.T_i = self.T_i * self.T_mult 62 | else: 63 | if epoch >= self.T_0: 64 | if self.T_mult == 1: 65 | self.T_cur = epoch % self.T_0 66 | else: 67 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 68 | self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (self.T_mult - 1) 69 | self.T_i = self.T_0 * self.T_mult**(n) 70 | else: 71 | self.T_i = self.T_0 72 | self.T_cur = epoch 73 | self.last_epoch = math.floor(epoch) 74 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 75 | param_group['lr'] = lr 76 | -------------------------------------------------------------------------------- /lib/optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class Adam(Optimizer): 7 | """Implements Adam algorithm. 8 | 9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 21 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 22 | (default: False) 23 | 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _On the Convergence of Adam and Beyond: 27 | https://openreview.net/forum?id=ryQu7f-RZ 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): 31 | if not 0.0 <= lr: 32 | raise ValueError("Invalid learning rate: {}".format(lr)) 33 | if not 0.0 <= eps: 34 | raise ValueError("Invalid epsilon value: {}".format(eps)) 35 | if not 0.0 <= betas[0] < 1.0: 36 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 37 | if not 0.0 <= betas[1] < 1.0: 38 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 39 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) 40 | super(Adam, self).__init__(params, defaults) 41 | 42 | def __setstate__(self, state): 43 | super(Adam, self).__setstate__(state) 44 | for group in self.param_groups: 45 | group.setdefault('amsgrad', False) 46 | 47 | def step(self, closure=None): 48 | """Performs a single optimization step. 49 | 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 
53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | grad = p.grad.data 63 | if grad.is_sparse: 64 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 65 | amsgrad = group['amsgrad'] 66 | 67 | state = self.state[p] 68 | 69 | # State initialization 70 | if len(state) == 0: 71 | state['step'] = 0 72 | # Exponential moving average of gradient values 73 | state['exp_avg'] = torch.zeros_like(p.data) 74 | # Exponential moving average of squared gradient values 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | if amsgrad: 77 | # Maintains max of all exp. moving avg. of sq. grad. values 78 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 79 | 80 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 81 | if amsgrad: 82 | max_exp_avg_sq = state['max_exp_avg_sq'] 83 | beta1, beta2 = group['betas'] 84 | 85 | state['step'] += 1 86 | 87 | # Decay the first and second moment running average coefficient 88 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 89 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 90 | if amsgrad: 91 | # Maintains the maximum of all 2nd moment running avg. till now 92 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 93 | # Use the max. for normalizing running avg. 
of gradient 94 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 95 | else: 96 | denom = exp_avg_sq.sqrt().add_(group['eps']) 97 | 98 | bias_correction1 = 1 - beta1**state['step'] 99 | bias_correction2 = 1 - beta2**state['step'] 100 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 101 | 102 | p.data.addcdiv_(-step_size, exp_avg, denom) 103 | 104 | if group['weight_decay'] != 0: 105 | p.data.add(-step_size * group['weight_decay'], p.data) 106 | 107 | return loss 108 | 109 | 110 | class Adamax(Optimizer): 111 | """Implements Adamax algorithm (a variant of Adam based on infinity norm). 112 | 113 | It has been proposed in `Adam: A Method for Stochastic Optimization`__. 114 | 115 | Arguments: 116 | params (iterable): iterable of parameters to optimize or dicts defining 117 | parameter groups 118 | lr (float, optional): learning rate (default: 2e-3) 119 | betas (Tuple[float, float], optional): coefficients used for computing 120 | running averages of gradient and its square 121 | eps (float, optional): term added to the denominator to improve 122 | numerical stability (default: 1e-8) 123 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 124 | 125 | __ https://arxiv.org/abs/1412.6980 126 | """ 127 | 128 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 129 | if not 0.0 <= lr: 130 | raise ValueError("Invalid learning rate: {}".format(lr)) 131 | if not 0.0 <= eps: 132 | raise ValueError("Invalid epsilon value: {}".format(eps)) 133 | if not 0.0 <= betas[0] < 1.0: 134 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 135 | if not 0.0 <= betas[1] < 1.0: 136 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 137 | if not 0.0 <= weight_decay: 138 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 139 | 140 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 141 | super(Adamax, self).__init__(params, 
defaults) 142 | 143 | def step(self, closure=None): 144 | """Performs a single optimization step. 145 | 146 | Arguments: 147 | closure (callable, optional): A closure that reevaluates the model 148 | and returns the loss. 149 | """ 150 | loss = None 151 | if closure is not None: 152 | loss = closure() 153 | 154 | for group in self.param_groups: 155 | for p in group['params']: 156 | if p.grad is None: 157 | continue 158 | grad = p.grad.data 159 | if grad.is_sparse: 160 | raise RuntimeError('Adamax does not support sparse gradients') 161 | state = self.state[p] 162 | 163 | # State initialization 164 | if len(state) == 0: 165 | state['step'] = 0 166 | state['exp_avg'] = torch.zeros_like(p.data) 167 | state['exp_inf'] = torch.zeros_like(p.data) 168 | 169 | exp_avg, exp_inf = state['exp_avg'], state['exp_inf'] 170 | beta1, beta2 = group['betas'] 171 | eps = group['eps'] 172 | 173 | state['step'] += 1 174 | 175 | # Update biased first moment estimate. 176 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 177 | # Update the exponentially weighted infinity norm. 178 | norm_buf = torch.cat([exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0) 179 | torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long())) 180 | 181 | bias_correction = 1 - beta1**state['step'] 182 | clr = group['lr'] / bias_correction 183 | 184 | p.data.addcdiv_(-clr, exp_avg, exp_inf) 185 | 186 | if group['weight_decay'] != 0: 187 | p.data.add(-clr * group['weight_decay'], p.data) 188 | 189 | return loss 190 | 191 | 192 | class RMSprop(Optimizer): 193 | """Implements RMSprop algorithm. 194 | 195 | Proposed by G. Hinton in his 196 | `course `_. 197 | 198 | The centered version first appears in `Generating Sequences 199 | With Recurrent Neural Networks `_. 
200 | 201 | Arguments: 202 | params (iterable): iterable of parameters to optimize or dicts defining 203 | parameter groups 204 | lr (float, optional): learning rate (default: 1e-2) 205 | momentum (float, optional): momentum factor (default: 0) 206 | alpha (float, optional): smoothing constant (default: 0.99) 207 | eps (float, optional): term added to the denominator to improve 208 | numerical stability (default: 1e-8) 209 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 210 | the gradient is normalized by an estimation of its variance 211 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 212 | 213 | """ 214 | 215 | def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False): 216 | if not 0.0 <= lr: 217 | raise ValueError("Invalid learning rate: {}".format(lr)) 218 | if not 0.0 <= eps: 219 | raise ValueError("Invalid epsilon value: {}".format(eps)) 220 | if not 0.0 <= momentum: 221 | raise ValueError("Invalid momentum value: {}".format(momentum)) 222 | if not 0.0 <= weight_decay: 223 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 224 | if not 0.0 <= alpha: 225 | raise ValueError("Invalid alpha value: {}".format(alpha)) 226 | 227 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) 228 | super(RMSprop, self).__init__(params, defaults) 229 | 230 | def __setstate__(self, state): 231 | super(RMSprop, self).__setstate__(state) 232 | for group in self.param_groups: 233 | group.setdefault('momentum', 0) 234 | group.setdefault('centered', False) 235 | 236 | def step(self, closure=None): 237 | """Performs a single optimization step. 238 | 239 | Arguments: 240 | closure (callable, optional): A closure that reevaluates the model 241 | and returns the loss. 
242 | """ 243 | loss = None 244 | if closure is not None: 245 | loss = closure() 246 | 247 | for group in self.param_groups: 248 | for p in group['params']: 249 | if p.grad is None: 250 | continue 251 | grad = p.grad.data 252 | if grad.is_sparse: 253 | raise RuntimeError('RMSprop does not support sparse gradients') 254 | state = self.state[p] 255 | 256 | # State initialization 257 | if len(state) == 0: 258 | state['step'] = 0 259 | state['square_avg'] = torch.zeros_like(p.data) 260 | if group['momentum'] > 0: 261 | state['momentum_buffer'] = torch.zeros_like(p.data) 262 | if group['centered']: 263 | state['grad_avg'] = torch.zeros_like(p.data) 264 | 265 | square_avg = state['square_avg'] 266 | alpha = group['alpha'] 267 | 268 | state['step'] += 1 269 | 270 | square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) 271 | 272 | if group['centered']: 273 | grad_avg = state['grad_avg'] 274 | grad_avg.mul_(alpha).add_(1 - alpha, grad) 275 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps']) 276 | else: 277 | avg = square_avg.sqrt().add_(group['eps']) 278 | 279 | if group['momentum'] > 0: 280 | buf = state['momentum_buffer'] 281 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 282 | p.data.add_(-group['lr'], buf) 283 | else: 284 | p.data.addcdiv_(-group['lr'], grad, avg) 285 | 286 | if group['weight_decay'] != 0: 287 | p.data.add(-group['lr'] * group['weight_decay'], p.data) 288 | 289 | return loss 290 | -------------------------------------------------------------------------------- /lib/tabular.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from collections import Counter 4 | 5 | import numpy as np 6 | 7 | import torch 8 | 9 | import pandas 10 | 11 | import h5py 12 | 13 | 14 | class SupervisedDataset(torch.utils.data.Dataset): 15 | def __init__(self, name, role, x, y=None): 16 | if y is None: 17 | y = torch.zeros(x.shape[0]).long() 18 | 19 | assert x.shape[0] == y.shape[0] 20 | assert 
role in ["train", "valid", "test"] 21 | 22 | self.name = name 23 | self.role = role 24 | 25 | self.x = x 26 | self.y = y 27 | 28 | def __len__(self): 29 | return self.x.shape[0] 30 | 31 | def __getitem__(self, index): 32 | return self.x[index], self.y[index] 33 | 34 | def to(self, device): 35 | return SupervisedDataset( 36 | self.name, 37 | self.role, 38 | self.x.to(device), 39 | self.y.to(device) 40 | ) 41 | 42 | 43 | def normalize_raw_data(data, mu, s): 44 | return (data - mu)/s 45 | 46 | 47 | def make_tabular_train_valid_split(data, frac): 48 | n_valid = int(frac*data.shape[0]) 49 | valid_data = data[-n_valid:] 50 | train_data = data[0:-n_valid] 51 | return train_data, valid_data 52 | 53 | 54 | def make_tabular_train_valid_test_split(data, frac): 55 | n_test = int(frac*data.shape[0]) 56 | test_data = data[-n_test:] 57 | data = data[0:-n_test] 58 | 59 | train_data, valid_data = make_tabular_train_valid_split(data, frac) 60 | return train_data, valid_data, test_data 61 | 62 | 63 | def get_miniboone_raw(data_root): 64 | data = np.load(os.path.join(data_root, "miniboone/data.npy")) 65 | 66 | train_raw, valid_raw, test_raw = make_tabular_train_valid_test_split(data, 0.1) 67 | 68 | data_stack = np.vstack((train_raw, valid_raw)) 69 | mu = data_stack.mean(axis=0) 70 | s = data_stack.std(axis=0) 71 | 72 | train_raw = normalize_raw_data(train_raw, mu, s) 73 | valid_raw = normalize_raw_data(valid_raw, mu, s) 74 | test_raw = normalize_raw_data(test_raw, mu, s) 75 | 76 | return train_raw, valid_raw, test_raw 77 | 78 | 79 | def get_gas_raw(data_root): 80 | 81 | def get_gas_correlation_numbers(data): 82 | C = data.corr() 83 | A = C > 0.98 84 | B = A.to_numpy().sum(axis=1) 85 | return B 86 | 87 | data = pandas.read_pickle(os.path.join(data_root, "gas/ethylene_CO.pickle")) 88 | data.drop("Meth", axis=1, inplace=True) 89 | data.drop("Eth", axis=1, inplace=True) 90 | data.drop("Time", axis=1, inplace=True) 91 | 92 | B = get_gas_correlation_numbers(data) 93 | while np.any(B > 1): 
94 | col_to_remove = np.where(B > 1)[0][0] 95 | col_name = data.columns[col_to_remove] 96 | data.drop(col_name, axis=1, inplace=True) 97 | B = get_gas_correlation_numbers(data) 98 | 99 | data = normalize_raw_data(data, data.mean(), data.std()).to_numpy() 100 | return make_tabular_train_valid_test_split(data, 0.1) 101 | 102 | 103 | def get_hepmass_raw(data_root): 104 | train_data_path = os.path.join(data_root, "hepmass/1000_train.csv") 105 | test_data_path = os.path.join(data_root, "hepmass/1000_test.csv") 106 | 107 | train_raw = pandas.read_csv(filepath_or_buffer=train_data_path, index_col=False) 108 | test_raw = pandas.read_csv(filepath_or_buffer=test_data_path, index_col=False) 109 | 110 | train_raw = train_raw[train_raw[train_raw.columns[0]] == 1] 111 | train_raw = train_raw.drop(train_raw.columns[0], axis=1) 112 | 113 | test_raw = test_raw[test_raw[test_raw.columns[0]] == 1] 114 | test_raw = test_raw.drop(test_raw.columns[0], axis=1) 115 | test_raw = test_raw.drop(test_raw.columns[-1], axis=1) 116 | 117 | mu = train_raw.mean() 118 | s = train_raw.std() 119 | train_raw = normalize_raw_data(train_raw, mu, s).to_numpy() 120 | test_raw = normalize_raw_data(test_raw, mu, s).to_numpy() 121 | 122 | i = 0 123 | features_to_remove = [] 124 | for feature in train_raw.T: 125 | c = Counter(feature) 126 | max_count = np.array([v for k, v in sorted(c.items())])[0] 127 | if max_count > 5: 128 | features_to_remove.append(i) 129 | i += 1 130 | train_raw = train_raw[:, np.array([i for i in range(train_raw.shape[1]) if i not in features_to_remove])] 131 | test_raw = test_raw[:, np.array([i for i in range(test_raw.shape[1]) if i not in features_to_remove])] 132 | 133 | train_raw, valid_raw = make_tabular_train_valid_split(train_raw, 0.1) 134 | return train_raw, valid_raw, test_raw 135 | 136 | 137 | def get_power_raw(data_root): 138 | data = np.load(os.path.join(data_root, "power/data.npy")) 139 | np.random.shuffle(data) 140 | n = data.shape[0] 141 | 142 | data = np.delete(data, 3, 
axis=1) 143 | data = np.delete(data, 1, axis=1) 144 | 145 | gap_noise = 0.001*np.random.rand(n, 1) 146 | voltage_noise = 0.01*np.random.rand(n, 1) 147 | sm_noise = np.random.rand(n, 3) 148 | time_noise = np.zeros((n, 1)) 149 | 150 | noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise)) 151 | data = data + noise 152 | 153 | train_raw, valid_raw, test_raw = make_tabular_train_valid_test_split(data, 0.1) 154 | 155 | train_and_valid = np.vstack((train_raw, valid_raw)) 156 | mu = train_and_valid.mean(axis=0) 157 | s = train_and_valid.std(axis=0) 158 | 159 | train_raw = normalize_raw_data(train_raw, mu, s) 160 | valid_raw = normalize_raw_data(valid_raw, mu, s) 161 | test_raw = normalize_raw_data(test_raw, mu, s) 162 | 163 | return train_raw, valid_raw, test_raw 164 | 165 | 166 | def get_bsds300_raw(data_root): 167 | with h5py.File(os.path.join(data_root, "BSDS300", "BSDS300.hdf5"), "r") as f: 168 | train_raw = f["train"][()] 169 | valid_raw = f["validation"][()] 170 | test_raw = f["test"][()] 171 | return train_raw, valid_raw, test_raw 172 | 173 | 174 | def get_raw_tabular_datasets(name, data_root): 175 | if name == "miniboone": 176 | data_fn = get_miniboone_raw 177 | elif name == "gas": 178 | data_fn = get_gas_raw 179 | elif name == "hepmass": 180 | data_fn = get_hepmass_raw 181 | elif name == "power": 182 | data_fn = get_power_raw 183 | elif name == "bsds300": 184 | data_fn = get_bsds300_raw 185 | else: 186 | raise NotImplementedError 187 | 188 | return data_fn(data_root) 189 | 190 | 191 | def get_tabular_datasets(name, data_root): 192 | train_raw, valid_raw, test_raw = get_raw_tabular_datasets(name, data_root) 193 | print(train_raw.shape) 194 | 195 | train_dset = SupervisedDataset(name, "train", 196 | torch.tensor(train_raw, dtype=torch.get_default_dtype())) 197 | valid_dset = SupervisedDataset(name, "valid", 198 | torch.tensor(valid_raw, dtype=torch.get_default_dtype())) 199 | test_dset = SupervisedDataset(name, "test", 200 | torch.tensor(test_raw, 
dtype=torch.get_default_dtype())) 201 | 202 | return train_dset, valid_dset, test_dset -------------------------------------------------------------------------------- /lib/toy_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn 3 | import sklearn.datasets 4 | from sklearn.utils import shuffle as util_shuffle 5 | 6 | 7 | # Dataset iterator 8 | def inf_train_gen(data, batch_size=200): 9 | 10 | if data == "swissroll": 11 | data = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0)[0] 12 | data = data.astype("float32")[:, [0, 2]] 13 | data /= 5 14 | return data 15 | 16 | elif data == "circles": 17 | data = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08)[0] 18 | data = data.astype("float32") 19 | data *= 3 20 | return data 21 | 22 | elif data == "rings": 23 | n_samples4 = n_samples3 = n_samples2 = batch_size // 4 24 | n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2 25 | 26 | # so as not to have the first point = last point, we set endpoint=False 27 | linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False) 28 | linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False) 29 | linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False) 30 | linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False) 31 | 32 | circ4_x = np.cos(linspace4) 33 | circ4_y = np.sin(linspace4) 34 | circ3_x = np.cos(linspace4) * 0.75 35 | circ3_y = np.sin(linspace3) * 0.75 36 | circ2_x = np.cos(linspace2) * 0.5 37 | circ2_y = np.sin(linspace2) * 0.5 38 | circ1_x = np.cos(linspace1) * 0.25 39 | circ1_y = np.sin(linspace1) * 0.25 40 | 41 | X = np.vstack([ 42 | np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]), 43 | np.hstack([circ4_y, circ3_y, circ2_y, circ1_y]) 44 | ]).T * 3.0 45 | X = util_shuffle(X) 46 | 47 | # Add noise 48 | X = X + np.random.normal(scale=0.08, size=X.shape) 49 | 50 | return X.astype("float32") 51 | 52 | elif data == 
"moons": 53 | data = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0] 54 | data = data.astype("float32") 55 | data = data * 2 + np.array([-1, -0.2]) 56 | return data 57 | 58 | elif data == "8gaussians": 59 | scale = 4. 60 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 61 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 62 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))] 63 | centers = [(scale * x, scale * y) for x, y in centers] 64 | 65 | dataset = [] 66 | for i in range(batch_size): 67 | point = np.random.randn(2) * 0.5 68 | idx = np.random.randint(8) 69 | center = centers[idx] 70 | point[0] += center[0] 71 | point[1] += center[1] 72 | dataset.append(point) 73 | dataset = np.array(dataset, dtype="float32") 74 | dataset /= 1.414 75 | return dataset 76 | 77 | elif data == "pinwheel": 78 | radial_std = 0.3 79 | tangential_std = 0.1 80 | num_classes = 5 81 | num_per_class = batch_size // 5 82 | rate = 0.25 83 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 84 | 85 | features = np.random.randn(num_classes*num_per_class, 2) \ 86 | * np.array([radial_std, tangential_std]) 87 | features[:, 0] += 1. 
88 | labels = np.repeat(np.arange(num_classes), num_per_class) 89 | 90 | angles = rads[labels] + rate * np.exp(features[:, 0]) 91 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 92 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 93 | 94 | return 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations)) 95 | 96 | elif data == "2spirals": 97 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 98 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 99 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 100 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 101 | x += np.random.randn(*x.shape) * 0.1 102 | return x 103 | 104 | elif data == "checkerboard": 105 | x1 = np.random.rand(batch_size) * 4 - 2 106 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 107 | x2 = x2_ + (np.floor(x1) % 2) 108 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 109 | 110 | elif data == "line": 111 | x = np.random.rand(batch_size) * 5 - 2.5 112 | y = x 113 | return np.stack((x, y), 1) 114 | elif data == "cos": 115 | x = np.random.rand(batch_size) * 5 - 2.5 116 | y = np.sin(x) * 2.5 117 | return np.stack((x, y), 1) 118 | else: 119 | return inf_train_gen("8gaussians", batch_size) 120 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from numbers import Number 4 | import logging 5 | import torch 6 | 7 | 8 | def makedirs(dirname): 9 | if not os.path.exists(dirname): 10 | os.makedirs(dirname) 11 | 12 | 13 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 14 | logger = logging.getLogger() 15 | if debug: 16 | level = logging.DEBUG 17 | else: 18 | level = logging.INFO 19 | logger.setLevel(level) 20 | if saving: 21 | 
info_file_handler = logging.FileHandler(logpath, mode="a") 22 | info_file_handler.setLevel(level) 23 | logger.addHandler(info_file_handler) 24 | if displaying: 25 | console_handler = logging.StreamHandler() 26 | console_handler.setLevel(level) 27 | logger.addHandler(console_handler) 28 | logger.info(filepath) 29 | with open(filepath, "r") as f: 30 | logger.info(f.read()) 31 | 32 | for f in package_files: 33 | logger.info(f) 34 | with open(f, "r") as package_f: 35 | logger.info(package_f.read()) 36 | 37 | return logger 38 | 39 | 40 | class AverageMeter(object): 41 | """Computes and stores the average and current value""" 42 | 43 | def __init__(self): 44 | self.reset() 45 | 46 | def reset(self): 47 | self.val = 0 48 | self.avg = 0 49 | self.sum = 0 50 | self.count = 0 51 | 52 | def update(self, val, n=1): 53 | self.val = val 54 | self.sum += val * n 55 | self.count += n 56 | self.avg = self.sum / self.count 57 | 58 | 59 | class RunningAverageMeter(object): 60 | """Computes and stores the average and current value""" 61 | 62 | def __init__(self, momentum=0.99): 63 | self.momentum = momentum 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = None 68 | self.avg = 0 69 | 70 | def update(self, val): 71 | if self.val is None: 72 | self.avg = val 73 | else: 74 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 75 | self.val = val 76 | 77 | 78 | def inf_generator(iterable): 79 | """Allows training with DataLoaders in a single infinite loop: 80 | for i, (x, y) in enumerate(inf_generator(train_loader)): 81 | """ 82 | iterator = iterable.__iter__() 83 | while True: 84 | try: 85 | yield iterator.__next__() 86 | except StopIteration: 87 | iterator = iterable.__iter__() 88 | 89 | 90 | def save_checkpoint(state, save, epoch, last_checkpoints=None, num_checkpoints=None): 91 | if not os.path.exists(save): 92 | os.makedirs(save) 93 | filename = os.path.join(save, 'checkpt-%04d.pth' % epoch) 94 | torch.save(state, filename) 95 | 96 | if last_checkpoints is not 
None and num_checkpoints is not None: 97 | last_checkpoints.append(epoch) 98 | if len(last_checkpoints) > num_checkpoints: 99 | rm_epoch = last_checkpoints.pop(0) 100 | os.remove(os.path.join(save, 'checkpt-%04d.pth' % rm_epoch)) 101 | 102 | 103 | def isnan(tensor): 104 | return (tensor != tensor) 105 | 106 | 107 | def logsumexp(value, dim=None, keepdim=False): 108 | """Numerically stable implementation of the operation 109 | value.exp().sum(dim, keepdim).log() 110 | """ 111 | if dim is not None: 112 | m, _ = torch.max(value, dim=dim, keepdim=True) 113 | value0 = value - m 114 | if keepdim is False: 115 | m = m.squeeze(dim) 116 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 117 | else: 118 | m = torch.max(value) 119 | sum_exp = torch.sum(torch.exp(value - m)) 120 | if isinstance(sum_exp, Number): 121 | return m + math.log(sum_exp) 122 | else: 123 | return m + torch.log(sum_exp) 124 | 125 | 126 | class ExponentialMovingAverage(object): 127 | 128 | def __init__(self, module, decay=0.999): 129 | """Initializes the model when .apply() is called the first time. 
130 | This is to take into account data-dependent initialization that occurs in the first iteration.""" 131 | self.module = module 132 | self.decay = decay 133 | self.shadow_params = {} 134 | self.nparams = sum(p.numel() for p in module.parameters()) 135 | 136 | def init(self): 137 | for name, param in self.module.named_parameters(): 138 | self.shadow_params[name] = param.data.clone() 139 | 140 | def apply(self): 141 | if len(self.shadow_params) == 0: 142 | self.init() 143 | else: 144 | with torch.no_grad(): 145 | for name, param in self.module.named_parameters(): 146 | self.shadow_params[name] -= (1 - self.decay) * (self.shadow_params[name] - param.data) 147 | 148 | def set(self, other_ema): 149 | self.init() 150 | with torch.no_grad(): 151 | for name, param in other_ema.shadow_params.items(): 152 | self.shadow_params[name].copy_(param) 153 | 154 | def replace_with_ema(self): 155 | for name, param in self.module.named_parameters(): 156 | param.data.copy_(self.shadow_params[name]) 157 | 158 | def swap(self): 159 | for name, param in self.module.named_parameters(): 160 | tmp = self.shadow_params[name].clone() 161 | self.shadow_params[name].copy_(param.data) 162 | param.data.copy_(tmp) 163 | 164 | def __repr__(self): 165 | return ( 166 | '{}(decay={}, module={}, nparams={})'.format( 167 | self.__class__.__name__, self.decay, self.module.__class__.__name__, self.nparams 168 | ) 169 | ) 170 | -------------------------------------------------------------------------------- /lib/visualize_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use("Agg") 4 | import matplotlib.pyplot as plt 5 | import torch 6 | 7 | LOW = -4 8 | HIGH = 4 9 | 10 | 11 | def plt_potential_func(potential, ax, npts=100, title="$p(x)$"): 12 | """ 13 | Args: 14 | potential: computes U(z_k) given z_k 15 | """ 16 | xside = np.linspace(LOW, HIGH, npts) 17 | yside = np.linspace(LOW, HIGH, npts) 18 | xx, yy = 
np.meshgrid(xside, yside) 19 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 20 | 21 | z = torch.Tensor(z) 22 | u = potential(z).cpu().numpy() 23 | p = np.exp(-u).reshape(npts, npts) 24 | 25 | plt.pcolormesh(xx, yy, p) 26 | ax.invert_yaxis() 27 | ax.get_xaxis().set_ticks([]) 28 | ax.get_yaxis().set_ticks([]) 29 | ax.set_title(title) 30 | 31 | 32 | def plt_flow(prior_logdensity, transform, ax, npts=100, title="$q(x)$", device="cpu"): 33 | """ 34 | Args: 35 | transform: computes z_k and log(q_k) given z_0 36 | """ 37 | side = np.linspace(LOW, HIGH, npts) 38 | xx, yy = np.meshgrid(side, side) 39 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 40 | 41 | z = torch.tensor(z, requires_grad=True).type(torch.float32).to(device) 42 | logqz = prior_logdensity(z) 43 | logqz = torch.sum(logqz, dim=1)[:, None] 44 | z, logqz = transform(z, logqz) 45 | logqz = torch.sum(logqz, dim=1)[:, None] 46 | 47 | xx = z[:, 0].cpu().numpy().reshape(npts, npts) 48 | yy = z[:, 1].cpu().numpy().reshape(npts, npts) 49 | qz = np.exp(logqz.cpu().numpy()).reshape(npts, npts) 50 | 51 | plt.pcolormesh(xx, yy, qz) 52 | ax.set_xlim(LOW, HIGH) 53 | ax.set_ylim(LOW, HIGH) 54 | cmap = matplotlib.cm.get_cmap(None) 55 | ax.set_facecolor(cmap(0.)) 56 | ax.invert_yaxis() 57 | ax.get_xaxis().set_ticks([]) 58 | ax.get_yaxis().set_ticks([]) 59 | ax.set_title(title) 60 | 61 | 62 | def plt_flow_density(prior_logdensity, inverse_transform, ax, npts=100, memory=100, title="$q(x)$", device="cpu"): 63 | side = np.linspace(LOW, HIGH, npts) 64 | xx, yy = np.meshgrid(side, side) 65 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 66 | 67 | x = torch.from_numpy(x).type(torch.float32).to(device) 68 | zeros = torch.zeros(x.shape[0], 1).to(x) 69 | 70 | z, delta_logp = [], [] 71 | inds = torch.arange(0, x.shape[0]).to(torch.int64) 72 | for ii in torch.split(inds, int(memory**2)): 73 | z_, delta_logp_ = inverse_transform(x[ii], zeros[ii]) 74 | z.append(z_) 75 | delta_logp.append(delta_logp_) 76 | z = 
torch.cat(z, 0) 77 | delta_logp = torch.cat(delta_logp, 0) 78 | 79 | logpz = prior_logdensity(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) 80 | logpx = logpz - delta_logp 81 | 82 | px = np.exp(logpx.cpu().numpy()).reshape(npts, npts) 83 | 84 | ax.imshow(px, cmap='inferno') 85 | ax.get_xaxis().set_ticks([]) 86 | ax.get_yaxis().set_ticks([]) 87 | ax.set_title(title) 88 | 89 | 90 | def plt_flow_samples(prior_sample, transform, ax, npts=100, memory=100, title="$x ~ q(x)$", device="cpu"): 91 | z = prior_sample(npts * npts, 2).type(torch.float32).to(device) 92 | zk = [] 93 | inds = torch.arange(0, z.shape[0]).to(torch.int64) 94 | for ii in torch.split(inds, int(memory**2)): 95 | zk.append(transform(z[ii])) 96 | zk = torch.cat(zk, 0).cpu().numpy() 97 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts, cmap='inferno') 98 | ax.invert_yaxis() 99 | ax.get_xaxis().set_ticks([]) 100 | ax.get_yaxis().set_ticks([]) 101 | ax.set_title(title) 102 | 103 | 104 | def plt_samples(samples, ax, npts=100, title="$x ~ p(x)$"): 105 | ax.hist2d(samples[:, 0], samples[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts, cmap='inferno') 106 | ax.invert_yaxis() 107 | ax.get_xaxis().set_ticks([]) 108 | ax.get_yaxis().set_ticks([]) 109 | ax.set_title(title) 110 | 111 | 112 | def visualize_transform( 113 | potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100, 114 | memory=100, device="cpu" 115 | ): 116 | """Produces visualization for the model density and samples from the model.""" 117 | plt.clf() 118 | ax = plt.subplot(1, 3, 1, aspect="equal") 119 | if samples: 120 | plt_samples(potential_or_samples, ax, npts=npts) 121 | else: 122 | plt_potential_func(potential_or_samples, ax, npts=npts) 123 | 124 | ax = plt.subplot(1, 3, 2, aspect="equal") 125 | if inverse_transform is None: 126 | plt_flow(prior_density, transform, ax, npts=npts, device=device) 127 | else: 128 | plt_flow_density(prior_density, 
inverse_transform, ax, npts=npts, memory=memory, device=device) 129 | 130 | ax = plt.subplot(1, 3, 3, aspect="equal") 131 | if transform is not None: 132 | plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device) 133 | -------------------------------------------------------------------------------- /preprocessing/convert_to_pth.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | import numpy as np 4 | import torch 5 | 6 | img = torch.tensor(np.load(sys.argv[1])) 7 | img = img.permute(0, 3, 1, 2) 8 | torch.save(img, re.sub('.npy$', '.pth', sys.argv[1])) 9 | -------------------------------------------------------------------------------- /preprocessing/create_imagenet_benchmark_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run the following commands in ~ before running this file 3 | wget http://image-net.org/small/train_64x64.tar 4 | wget http://image-net.org/small/valid_64x64.tar 5 | tar -xvf train_64x64.tar 6 | tar -xvf valid_64x64.tar 7 | wget http://image-net.org/small/train_32x32.tar 8 | wget http://image-net.org/small/valid_32x32.tar 9 | tar -xvf train_32x32.tar 10 | tar -xvf valid_32x32.tar 11 | """ 12 | 13 | import numpy as np 14 | import scipy.ndimage 15 | import os 16 | from os import listdir 17 | from os.path import isfile, join 18 | from tqdm import tqdm 19 | 20 | 21 | def convert_path_to_npy(*, path='train_64x64', outfile='train_64x64.npy'): 22 | assert isinstance(path, str), "Expected a string input for the path" 23 | assert os.path.exists(path), "Input path doesn't exist" 24 | files = [f for f in listdir(path) if isfile(join(path, f))] 25 | print('Number of valid images is:', len(files)) 26 | imgs = [] 27 | for i in tqdm(range(len(files))): 28 | img = scipy.ndimage.imread(join(path, files[i])) 29 | img = img.astype('uint8') 30 | assert np.max(img) <= 255 31 | assert np.min(img) >= 0 32 | assert 
img.dtype == 'uint8' 33 | assert isinstance(img, np.ndarray) 34 | imgs.append(img) 35 | resolution_x, resolution_y = img.shape[0], img.shape[1] 36 | imgs = np.asarray(imgs).astype('uint8') 37 | assert imgs.shape[1:] == (resolution_x, resolution_y, 3) 38 | assert np.max(imgs) <= 255 39 | assert np.min(imgs) >= 0 40 | print('Total number of images is:', imgs.shape[0]) 41 | print('All assertions done, dumping into npy file') 42 | np.save(outfile, imgs) 43 | 44 | 45 | if __name__ == '__main__': 46 | convert_path_to_npy(path='train_64x64', outfile='train_64x64.npy') 47 | convert_path_to_npy(path='valid_64x64', outfile='valid_64x64.npy') 48 | convert_path_to_npy(path='train_32x32', outfile='train_32x32.npy') 49 | convert_path_to_npy(path='valid_32x32', outfile='valid_32x32.npy') 50 | -------------------------------------------------------------------------------- /preprocessing/extract_celeba_from_tfrecords.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | 5 | sess = tf.InteractiveSession() 6 | 7 | train_imgs = [] 8 | 9 | print('Reading from training set...', flush=True) 10 | for i in range(120): 11 | tfr = 'data/celebahq/celeba-tfr/train/train-r08-s-{:04d}-of-0120.tfrecords'.format(i) 12 | print(tfr, flush=True) 13 | 14 | record_iterator = tf.python_io.tf_record_iterator(tfr) 15 | 16 | for string_record in record_iterator: 17 | example = tf.train.Example() 18 | example.ParseFromString(string_record) 19 | 20 | image_bytes = example.features.feature['data'].bytes_list.value[0] 21 | 22 | img = tf.decode_raw(image_bytes, tf.uint8) 23 | img = tf.reshape(img, [256, 256, 3]) 24 | img = img.eval() 25 | 26 | train_imgs.append(img) 27 | 28 | train_imgs = np.stack(train_imgs) 29 | train_imgs = torch.tensor(train_imgs).permute(0, 3, 1, 2) 30 | torch.save(train_imgs, 'data/celebahq/celeba256_train.pth') 31 | 32 | validation_imgs = [] 33 | for i in range(40): 34 | tfr = 
'data/celebahq/celeba-tfr/validation/validation-r08-s-{:04d}-of-0040.tfrecords'.format(i) 35 | print(tfr, flush=True) 36 | 37 | record_iterator = tf.python_io.tf_record_iterator(tfr) 38 | 39 | for string_record in record_iterator: 40 | example = tf.train.Example() 41 | example.ParseFromString(string_record) 42 | 43 | image_bytes = example.features.feature['data'].bytes_list.value[0] 44 | 45 | img = tf.decode_raw(image_bytes, tf.uint8) 46 | img = tf.reshape(img, [256, 256, 3]) 47 | img = img.eval() 48 | 49 | validation_imgs.append(img) 50 | 51 | validation_imgs = np.stack(validation_imgs) 52 | validation_imgs = torch.tensor(validation_imgs).permute(0, 3, 1, 2) 53 | torch.save(validation_imgs, 'data/celebahq/celeba256_validation.pth') 54 | -------------------------------------------------------------------------------- /qualitative_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | from tqdm import tqdm 4 | 5 | import torch 6 | import torchvision.transforms as transforms 7 | from torchvision.utils import save_image 8 | import torchvision.datasets as vdsets 9 | 10 | from lib.iresnet import ACT_FNS, ResidualFlow 11 | import lib.datasets as datasets 12 | import lib.utils as utils 13 | import lib.layers as layers 14 | import lib.layers.base as base_layers 15 | 16 | # Arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | '--data', type=str, default='cifar10', choices=[ 20 | 'mnist', 21 | 'cifar10', 22 | 'celeba', 23 | 'celebahq', 24 | 'celeba_5bit', 25 | 'imagenet32', 26 | 'imagenet64', 27 | ] 28 | ) 29 | parser.add_argument('--dataroot', type=str, default='data') 30 | parser.add_argument('--imagesize', type=int, default=32) 31 | parser.add_argument('--nbits', type=int, default=8) # Only used for celebahq. 32 | 33 | # Sampling parameters. 
34 | parser.add_argument('--real', type=eval, choices=[True, False], default=False) 35 | parser.add_argument('--nrow', type=int, default=10) 36 | parser.add_argument('--ncol', type=int, default=10) 37 | parser.add_argument('--temp', type=float, default=1.0) 38 | parser.add_argument('--nbatches', type=int, default=5) 39 | parser.add_argument('--save-each', type=eval, choices=[True, False], default=False) 40 | 41 | parser.add_argument('--block', type=str, choices=['resblock', 'coupling'], default='resblock') 42 | 43 | parser.add_argument('--coeff', type=float, default=0.98) 44 | parser.add_argument('--vnorms', type=str, default='2222') 45 | parser.add_argument('--n-lipschitz-iters', type=int, default=None) 46 | parser.add_argument('--sn-tol', type=float, default=1e-3) 47 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False) 48 | 49 | parser.add_argument('--n-power-series', type=int, default=None) 50 | parser.add_argument('--factor-out', type=eval, choices=[True, False], default=False) 51 | parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric') 52 | parser.add_argument('--n-samples', type=int, default=1) 53 | parser.add_argument('--n-exact-terms', type=int, default=2) 54 | parser.add_argument('--var-reduc-lr', type=float, default=0) 55 | parser.add_argument('--neumann-grad', type=eval, choices=[True, False], default=True) 56 | parser.add_argument('--mem-eff', type=eval, choices=[True, False], default=True) 57 | 58 | parser.add_argument('--act', type=str, choices=ACT_FNS.keys(), default='swish') 59 | parser.add_argument('--idim', type=int, default=512) 60 | parser.add_argument('--nblocks', type=str, default='16-16-16') 61 | parser.add_argument('--squeeze-first', type=eval, default=False, choices=[True, False]) 62 | parser.add_argument('--actnorm', type=eval, default=True, choices=[True, False]) 63 | parser.add_argument('--fc-actnorm', type=eval, default=False, choices=[True, False]) 64 | 
parser.add_argument('--batchnorm', type=eval, default=False, choices=[True, False]) 65 | parser.add_argument('--dropout', type=float, default=0.) 66 | parser.add_argument('--fc', type=eval, default=False, choices=[True, False]) 67 | parser.add_argument('--kernels', type=str, default='3-1-3') 68 | parser.add_argument('--quadratic', type=eval, choices=[True, False], default=False) 69 | parser.add_argument('--fc-end', type=eval, choices=[True, False], default=True) 70 | parser.add_argument('--fc-idim', type=int, default=128) 71 | parser.add_argument('--preact', type=eval, choices=[True, False], default=True) 72 | parser.add_argument('--padding', type=int, default=0) 73 | parser.add_argument('--first-resblock', type=eval, choices=[True, False], default=True) 74 | 75 | parser.add_argument('--task', type=str, choices=['density'], default='density') 76 | parser.add_argument('--rcrop-pad-mode', type=str, choices=['constant', 'reflect'], default='reflect') 77 | parser.add_argument('--padding-dist', type=str, choices=['uniform', 'gaussian'], default='uniform') 78 | 79 | parser.add_argument('--resume', type=str, required=True) 80 | parser.add_argument('--nworkers', type=int, default=4) 81 | args = parser.parse_args() 82 | 83 | W = args.ncol 84 | H = args.nrow 85 | 86 | args.batchsize = W * H 87 | args.val_batchsize = W * H 88 | 89 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 90 | 91 | if device.type == 'cuda': 92 | print('Found {} CUDA devices.'.format(torch.cuda.device_count())) 93 | for i in range(torch.cuda.device_count()): 94 | props = torch.cuda.get_device_properties(i) 95 | print('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3))) 96 | else: 97 | print('WARNING: Using device {}'.format(device)) 98 | 99 | 100 | def geometric_logprob(ns, p): 101 | return torch.log(1 - p + 1e-10) * (ns - 1) + torch.log(p + 1e-10) 102 | 103 | 104 | def standard_normal_sample(size): 105 | return torch.randn(size) 106 | 107 | 108 | def 
standard_normal_logprob(z): 109 | logZ = -0.5 * math.log(2 * math.pi) 110 | return logZ - z.pow(2) / 2 111 | 112 | 113 | def count_parameters(model): 114 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 115 | 116 | 117 | def add_noise(x, nvals=255): 118 | """ 119 | [0, 1] -> [0, nvals] -> add noise -> [0, 1] 120 | """ 121 | noise = x.new().resize_as_(x).uniform_() 122 | x = x * nvals + noise 123 | x = x / (nvals + 1) 124 | return x 125 | 126 | 127 | def update_lr(optimizer, itr): 128 | iter_frac = min(float(itr + 1) / max(args.warmup_iters, 1), 1.0) 129 | lr = args.lr * iter_frac 130 | for param_group in optimizer.param_groups: 131 | param_group["lr"] = lr 132 | 133 | 134 | def add_padding(x): 135 | if args.padding > 0: 136 | u = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).uniform_() 137 | logpu = torch.zeros_like(u).sum([1, 2, 3]) 138 | return torch.cat([u, x], dim=1), logpu 139 | else: 140 | return x, torch.zeros(x.shape[0]).to(x) 141 | 142 | 143 | def remove_padding(x): 144 | if args.padding > 0: 145 | return x[:, args.padding:, :, :] 146 | else: 147 | return x 148 | 149 | 150 | def reduce_bits(x): 151 | if args.nbits < 8: 152 | x = x * 255 153 | x = torch.floor(x / 2**(8 - args.nbits)) 154 | x = x / 2**args.nbits 155 | return x 156 | 157 | 158 | def update_lipschitz(model): 159 | for m in model.modules(): 160 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear): 161 | m.compute_weight(update=True) 162 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 163 | m.compute_weight(update=True) 164 | 165 | 166 | print('Loading dataset {}'.format(args.data), flush=True) 167 | # Dataset and hyperparameters 168 | if args.data == 'cifar10': 169 | im_dim = 3 170 | n_classes = 10 171 | 172 | if args.task in ['classification', 'hybrid']: 173 | 174 | if args.real: 175 | 176 | # Classification-specific preprocessing. 
177 | transform_train = transforms.Compose([ 178 | transforms.Resize(args.imagesize), 179 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode), 180 | transforms.RandomHorizontalFlip(), 181 | transforms.ToTensor(), 182 | add_noise, 183 | ]) 184 | 185 | transform_test = transforms.Compose([ 186 | transforms.Resize(args.imagesize), 187 | transforms.ToTensor(), 188 | add_noise, 189 | ]) 190 | 191 | # Remove the logit transform. 192 | init_layer = layers.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 193 | else: 194 | if args.real: 195 | transform_train = transforms.Compose([ 196 | transforms.Resize(args.imagesize), 197 | transforms.RandomHorizontalFlip(), 198 | transforms.ToTensor(), 199 | add_noise, 200 | ]) 201 | transform_test = transforms.Compose([ 202 | transforms.Resize(args.imagesize), 203 | transforms.ToTensor(), 204 | add_noise, 205 | ]) 206 | init_layer = layers.LogitTransform(0.05) 207 | if args.real: 208 | train_loader = torch.utils.data.DataLoader( 209 | datasets.CIFAR10(args.dataroot, train=True, transform=transform_train), 210 | batch_size=args.batchsize, 211 | shuffle=True, 212 | num_workers=args.nworkers, 213 | ) 214 | test_loader = torch.utils.data.DataLoader( 215 | datasets.CIFAR10(args.dataroot, train=False, transform=transform_test), 216 | batch_size=args.val_batchsize, 217 | shuffle=False, 218 | num_workers=args.nworkers, 219 | ) 220 | elif args.data == 'mnist': 221 | im_dim = 1 222 | init_layer = layers.LogitTransform(1e-6) 223 | n_classes = 10 224 | 225 | if args.real: 226 | train_loader = torch.utils.data.DataLoader( 227 | datasets.MNIST( 228 | args.dataroot, train=True, transform=transforms.Compose([ 229 | transforms.Resize(args.imagesize), 230 | transforms.ToTensor(), 231 | add_noise, 232 | ]) 233 | ), 234 | batch_size=args.batchsize, 235 | shuffle=True, 236 | num_workers=args.nworkers, 237 | ) 238 | test_loader = torch.utils.data.DataLoader( 239 | datasets.MNIST( 240 | args.dataroot, train=False, 
transform=transforms.Compose([ 241 | transforms.Resize(args.imagesize), 242 | transforms.ToTensor(), 243 | add_noise, 244 | ]) 245 | ), 246 | batch_size=args.val_batchsize, 247 | shuffle=False, 248 | num_workers=args.nworkers, 249 | ) 250 | elif args.data == 'svhn': 251 | im_dim = 3 252 | init_layer = layers.LogitTransform(0.05) 253 | n_classes = 10 254 | 255 | if args.real: 256 | train_loader = torch.utils.data.DataLoader( 257 | vdsets.SVHN( 258 | args.dataroot, split='train', download=True, transform=transforms.Compose([ 259 | transforms.Resize(args.imagesize), 260 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode), 261 | transforms.ToTensor(), 262 | add_noise, 263 | ]) 264 | ), 265 | batch_size=args.batchsize, 266 | shuffle=True, 267 | num_workers=args.nworkers, 268 | ) 269 | test_loader = torch.utils.data.DataLoader( 270 | vdsets.SVHN( 271 | args.dataroot, split='test', download=True, transform=transforms.Compose([ 272 | transforms.Resize(args.imagesize), 273 | transforms.ToTensor(), 274 | add_noise, 275 | ]) 276 | ), 277 | batch_size=args.val_batchsize, 278 | shuffle=False, 279 | num_workers=args.nworkers, 280 | ) 281 | elif args.data == 'celebahq': 282 | im_dim = 3 283 | init_layer = layers.LogitTransform(0.05) 284 | 285 | if args.real: 286 | train_loader = torch.utils.data.DataLoader( 287 | datasets.CelebAHQ( 288 | train=True, transform=transforms.Compose([ 289 | transforms.ToPILImage(), 290 | transforms.RandomHorizontalFlip(), 291 | transforms.ToTensor(), 292 | reduce_bits, 293 | lambda x: add_noise(x, nvals=2**args.nbits), 294 | ]) 295 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 296 | ) 297 | test_loader = torch.utils.data.DataLoader( 298 | datasets.CelebAHQ( 299 | train=False, transform=transforms.Compose([ 300 | reduce_bits, 301 | lambda x: add_noise(x, nvals=2**args.nbits), 302 | ]) 303 | ), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 304 | ) 305 | elif args.data == 
'celeba_5bit': 306 | im_dim = 3 307 | init_layer = layers.LogitTransform(0.05) 308 | if args.imagesize != 64: 309 | print('Changing image size to 64.') 310 | args.imagesize = 64 311 | 312 | if args.real: 313 | train_loader = torch.utils.data.DataLoader( 314 | datasets.CelebA5bit( 315 | train=True, transform=transforms.Compose([ 316 | transforms.ToPILImage(), 317 | transforms.RandomHorizontalFlip(), 318 | transforms.ToTensor(), 319 | lambda x: add_noise(x, nvals=32), 320 | ]) 321 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 322 | ) 323 | test_loader = torch.utils.data.DataLoader( 324 | datasets.CelebA5bit(train=False, transform=transforms.Compose([ 325 | lambda x: add_noise(x, nvals=32), 326 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 327 | ) 328 | elif args.data == 'imagenet32': 329 | im_dim = 3 330 | init_layer = layers.LogitTransform(0.05) 331 | if args.imagesize != 32: 332 | print('Changing image size to 32.') 333 | args.imagesize = 32 334 | 335 | if args.real: 336 | train_loader = torch.utils.data.DataLoader( 337 | datasets.Imagenet32(train=True, transform=transforms.Compose([ 338 | add_noise, 339 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 340 | ) 341 | test_loader = torch.utils.data.DataLoader( 342 | datasets.Imagenet32(train=False, transform=transforms.Compose([ 343 | add_noise, 344 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 345 | ) 346 | elif args.data == 'imagenet64': 347 | im_dim = 3 348 | init_layer = layers.LogitTransform(0.05) 349 | if args.imagesize != 64: 350 | print('Changing image size to 64.') 351 | args.imagesize = 64 352 | 353 | if args.real: 354 | train_loader = torch.utils.data.DataLoader( 355 | datasets.Imagenet64(train=True, transform=transforms.Compose([ 356 | add_noise, 357 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 358 | ) 359 | test_loader = torch.utils.data.DataLoader( 360 | 
datasets.Imagenet64(train=False, transform=transforms.Compose([ 361 | add_noise, 362 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 363 | ) 364 | 365 | if args.task in ['classification', 'hybrid']: 366 | try: 367 | n_classes 368 | except NameError: 369 | raise ValueError('Cannot perform classification with {}'.format(args.data)) 370 | else: 371 | n_classes = 1 372 | 373 | print('Dataset loaded.', flush=True) 374 | print('Creating model.', flush=True) 375 | 376 | input_size = (args.batchsize, im_dim + args.padding, args.imagesize, args.imagesize) 377 | 378 | if args.squeeze_first: 379 | input_size = (input_size[0], input_size[1] * 4, input_size[2] // 2, input_size[3] // 2) 380 | squeeze_layer = layers.SqueezeLayer(2) 381 | 382 | # Model 383 | model = ResidualFlow( 384 | input_size, 385 | n_blocks=list(map(int, args.nblocks.split('-'))), 386 | intermediate_dim=args.idim, 387 | factor_out=args.factor_out, 388 | quadratic=args.quadratic, 389 | init_layer=init_layer, 390 | actnorm=args.actnorm, 391 | fc_actnorm=args.fc_actnorm, 392 | batchnorm=args.batchnorm, 393 | dropout=args.dropout, 394 | fc=args.fc, 395 | coeff=args.coeff, 396 | vnorms=args.vnorms, 397 | n_lipschitz_iters=args.n_lipschitz_iters, 398 | sn_atol=args.sn_tol, 399 | sn_rtol=args.sn_tol, 400 | n_power_series=args.n_power_series, 401 | n_dist=args.n_dist, 402 | n_samples=args.n_samples, 403 | kernels=args.kernels, 404 | activation_fn=args.act, 405 | fc_end=args.fc_end, 406 | fc_idim=args.fc_idim, 407 | n_exact_terms=args.n_exact_terms, 408 | preact=args.preact, 409 | neumann_grad=args.neumann_grad, 410 | grad_in_forward=args.mem_eff, 411 | first_resblock=args.first_resblock, 412 | learn_p=args.learn_p, 413 | block_type=args.block, 414 | ) 415 | 416 | model.to(device) 417 | 418 | print('Initializing model.', flush=True) 419 | 420 | with torch.no_grad(): 421 | x = torch.rand(1, *input_size[1:]).to(device) 422 | model(x) 423 | print('Restoring from checkpoint.', flush=True) 
424 | checkpt = torch.load(args.resume) 425 | state = model.state_dict() 426 | model.load_state_dict(checkpt['state_dict'], strict=True) 427 | 428 | ema = utils.ExponentialMovingAverage(model) 429 | ema.set(checkpt['ema']) 430 | ema.swap() 431 | 432 | print(model, flush=True) 433 | 434 | model.eval() 435 | print('Updating lipschitz.', flush=True) 436 | update_lipschitz(model) 437 | 438 | 439 | def visualize(model): 440 | utils.makedirs('{}_imgs_t{}'.format(args.data, args.temp)) 441 | 442 | with torch.no_grad(): 443 | 444 | for i in tqdm(range(args.nbatches)): 445 | # random samples 446 | rand_z = torch.randn(args.batchsize, (im_dim + args.padding) * args.imagesize * args.imagesize).to(device) 447 | rand_z = rand_z * args.temp 448 | fake_imgs = model(rand_z, inverse=True).view(-1, *input_size[1:]) 449 | if args.squeeze_first: fake_imgs = squeeze_layer.inverse(fake_imgs) 450 | fake_imgs = remove_padding(fake_imgs) 451 | fake_imgs = fake_imgs.view(-1, im_dim, args.imagesize, args.imagesize) 452 | fake_imgs = fake_imgs.cpu() 453 | 454 | if args.save_each: 455 | for j in range(fake_imgs.shape[0]): 456 | save_image( 457 | fake_imgs[j], '{}_imgs_t{}/{}.png'.format(args.data, args.temp, args.batchsize * i + j), nrow=1, 458 | padding=0, range=(0, 1), pad_value=0 459 | ) 460 | else: 461 | save_image( 462 | fake_imgs, 'imgs/{}_t{}_samples{}.png'.format(args.data, args.temp, i), nrow=W, padding=2, 463 | range=(0, 1), pad_value=1 464 | ) 465 | 466 | 467 | real_imgs = test_loader.__iter__().__next__()[0] if args.real else None 468 | if args.real: 469 | real_imgs = test_loader.__iter__().__next__()[0] 470 | save_image( 471 | real_imgs.cpu().float(), 'imgs/{}_real.png'.format(args.data), nrow=W, padding=2, range=(0, 1), pad_value=1 472 | ) 473 | 474 | visualize(model) 475 | -------------------------------------------------------------------------------- /run_cifar10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 
python train_img.py --data cifar10 --actnorm True \ 2 | --nblocks '2-2-2' --idim '512' --act 'swish' --kernels '3-1-3' --vnorms '2222' --fc-end False --preact True \ 3 | --save 'experiments/cifar10(blocks_2*3(512,k313)_swish_nofc_preact_10term' --coeff 0.9 --n-exact-terms 10 4 | -------------------------------------------------------------------------------- /run_classification.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train_classification.py --data 'cifar100' --model-dir 'experiments/classify_cifar100_Resnet18_c0.9' \ 2 | --weight-decay 0 --epochs 150 --log-interval 20 --batch-size 128 --test-batch-size 128 --lr 0.001 --coeff 0.9 -------------------------------------------------------------------------------- /run_tabular.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train_tabular.py --nblocks 20 --vnorms '222222' --dims '128-128-128-128' \ 2 | --save 'experiments/tabular_(power_block20,128*4,c99,sin)_bf' --act 'sin' --data 'power' --batchsize 1000 --coeff 0.99 --nepochs 10000 --epsf 1e-5 3 | -------------------------------------------------------------------------------- /run_toy.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train_toy.py --nblocks 6 --vnorms '2222' --dims '128-128' \ 2 | --arch 'implicit' --brute-force True --save 'experiments/res_toy(block6,128*2,c99,sin,5000)' --act 'sin' --data 'checkerboard' --batch_size 5000 --coeff 0.99 --n-lipschitz-iters 20 -------------------------------------------------------------------------------- /train_classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | import lib.utils as utils 7 | 8 | import os 9 | import argparse 10 | import torch 11 | 
import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torchvision 14 | import torch.optim as optim 15 | from torchvision import datasets, transforms 16 | import time 17 | import lib.layers.base as base_layers 18 | import lib.layers as layers 19 | 20 | parser = argparse.ArgumentParser(description='PyTorch CIFAR TRADES Adversarial Training') 21 | parser.add_argument( 22 | '--data', type=str, default='cifar10', choices=[ 23 | 'cifar10', 24 | 'cifar100', 25 | 'mnist', 26 | ] 27 | ) 28 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 29 | help='input batch size for training (default: 128)') 30 | parser.add_argument('--test-batch-size', type=int, default=128, metavar='N', 31 | help='input batch size for testing (default: 128)') 32 | parser.add_argument('--epochs', type=int, default=76, metavar='N', 33 | help='number of epochs to train') 34 | parser.add_argument('--weight-decay', '--wd', default=2e-4, 35 | type=float, metavar='W') 36 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 37 | help='learning rate') 38 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 39 | help='SGD momentum') 40 | parser.add_argument('--no-cuda', action='store_true', default=False, 41 | help='disables CUDA training') 42 | parser.add_argument('--epsilon', default=0.031, 43 | help='perturbation') 44 | parser.add_argument('--num-steps', default=10, 45 | help='perturb number of steps') 46 | parser.add_argument('--step-size', default=0.007, 47 | help='perturb step size') 48 | parser.add_argument('--beta', default=6.0, 49 | help='regularization, i.e., 1/lambda in TRADES') 50 | parser.add_argument('--seed', type=int, default=1, metavar='S', 51 | help='random seed (default: 1)') 52 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 53 | help='how many batches to wait before logging training status') 54 | parser.add_argument('--model-dir', default='./experiments/model-cifar-Resnet18', 55 | 
help='directory of model for saving checkpoint') 56 | parser.add_argument('--save-freq', '-s', default=50, type=int, metavar='N', 57 | help='save frequency') 58 | parser.add_argument('--coeff', type=float, default=0.99) 59 | 60 | args = parser.parse_args() 61 | 62 | # settings 63 | model_dir = args.model_dir 64 | utils.makedirs(model_dir) 65 | logger = utils.get_logger(logpath=os.path.join(model_dir, 'logs'), filepath=os.path.abspath(__file__)) 66 | logger.info(args) 67 | 68 | use_cuda = not args.no_cuda and torch.cuda.is_available() 69 | if use_cuda: 70 | torch.backends.cudnn.benchmark = True 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | torch.cuda.manual_seed(args.seed) 74 | device = torch.device("cuda" if use_cuda else "cpu") 75 | kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {} 76 | 77 | 78 | ACTIVATION_FNS = { 79 | 'identity': base_layers.Identity, 80 | 'relu': torch.nn.ReLU, 81 | 'tanh': torch.nn.Tanh, 82 | 'elu': torch.nn.ELU, 83 | 'selu': torch.nn.SELU, 84 | 'fullsort': base_layers.FullSort, 85 | 'maxmin': base_layers.MaxMin, 86 | 'swish': base_layers.Swish, 87 | 'lcube': base_layers.LipschitzCube, 88 | 'sin': base_layers.Sin, 89 | 'zero': base_layers.Zero, 90 | } 91 | 92 | class BasicBlock(nn.Module): 93 | expansion = 1 94 | 95 | def __init__(self, in_planes, hidden, planes, stride=1): 96 | super(BasicBlock, self).__init__() 97 | 98 | def build_net(): 99 | layer = nn.Conv2d 100 | nnet = [] 101 | nnet.append( 102 | layer( 103 | in_planes, hidden, kernel_size=3, stride=1, padding=1, bias=False 104 | ) 105 | ) 106 | nnet.append(nn.BatchNorm2d(hidden)) 107 | nnet.append(ACTIVATION_FNS['relu']()) 108 | nnet.append( 109 | layer( 110 | hidden, in_planes, kernel_size=3, stride=1, padding=1, bias=False 111 | ) 112 | ) 113 | nnet.append(nn.BatchNorm2d(in_planes)) 114 | nnet.append(ACTIVATION_FNS['relu']()) 115 | return nn.Sequential(*nnet) 116 | 117 | self.block1 = build_net() 118 | self.block2 = build_net() 119 | 120 | 
self.downsample = nn.Sequential() 121 | if stride != 1 or in_planes != self.expansion * planes: 122 | self.downsample = nn.Sequential( 123 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 124 | nn.BatchNorm2d(self.expansion * planes), 125 | ACTIVATION_FNS['relu'](), 126 | ) 127 | 128 | def forward(self, x): 129 | out = F.relu(x + self.block1(x)) 130 | out = out + self.block2(out) 131 | out = self.downsample(out) 132 | return out 133 | 134 | 135 | class BasicImplicitBlock(nn.Module): 136 | expansion = 1 137 | 138 | def __init__( 139 | self, 140 | in_planes, 141 | hidden, 142 | planes, 143 | stride=1, 144 | n_lipschitz_iters=None, 145 | sn_atol=1e-3, 146 | sn_rtol=1e-3, 147 | ): 148 | super(BasicImplicitBlock, self).__init__() 149 | coeff = args.coeff 150 | self.initialized = False 151 | 152 | def build_net(): 153 | layer = base_layers.get_conv2d 154 | nnet = [] 155 | nnet.append( 156 | layer( 157 | in_planes, hidden, kernel_size=3, stride=1, padding=1, bias=False, coeff=coeff, n_iterations=n_lipschitz_iters, domain=2, codomain=2, atol=sn_atol, rtol=sn_rtol, 158 | ) 159 | ) 160 | nnet.append(ACTIVATION_FNS['relu']()) 161 | nnet.append( 162 | layer( 163 | hidden, in_planes, kernel_size=3, stride=1, padding=1, bias=False, coeff=coeff, n_iterations=n_lipschitz_iters, domain=2, codomain=2, atol=sn_atol, rtol=sn_rtol, 164 | ) 165 | ) 166 | nnet.append(ACTIVATION_FNS['relu']()) 167 | return nn.Sequential(*nnet) 168 | 169 | self.block = layers.imBlock( 170 | build_net(), 171 | build_net(), 172 | ) 173 | self.downsample = nn.Sequential() 174 | if stride != 1 or in_planes != self.expansion * planes: 175 | self.downsample = nn.Sequential( 176 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 177 | nn.BatchNorm2d(self.expansion * planes), 178 | ACTIVATION_FNS['relu'](), 179 | ) 180 | 181 | def forward(self, x): 182 | if self.initialized: 183 | out = self.block(x) 184 | else: 185 | out = 
self.block(x, restore=True) 186 | self.initialized = True 187 | out = self.downsample(out) 188 | return out 189 | 190 | 191 | class Bottleneck(nn.Module): 192 | expansion = 4 193 | 194 | def __init__(self, in_planes, planes, stride=1): 195 | super(Bottleneck, self).__init__() 196 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 197 | self.bn1 = nn.BatchNorm2d(planes) 198 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 199 | self.bn2 = nn.BatchNorm2d(planes) 200 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 201 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 202 | 203 | self.shortcut = nn.Sequential() 204 | if stride != 1 or in_planes != self.expansion * planes: 205 | self.shortcut = nn.Sequential( 206 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 207 | nn.BatchNorm2d(self.expansion * planes) 208 | ) 209 | 210 | def forward(self, x): 211 | out = F.relu(self.bn1(self.conv1(x))) 212 | out = F.relu(self.bn2(self.conv2(out))) 213 | out = self.bn3(self.conv3(out)) 214 | out += self.shortcut(x) 215 | out = F.relu(out) 216 | return out 217 | 218 | 219 | class ResNet(nn.Module): 220 | def __init__(self, block, num_blocks, num_classes=10): 221 | super(ResNet, self).__init__() 222 | self.in_planes = 64 223 | 224 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 225 | self.bn1 = nn.BatchNorm2d(64) 226 | self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1) 227 | self.layer2 = self._make_layer(block, 128, 128, num_blocks[1], stride=2) 228 | self.layer3 = self._make_layer(block, 256, 256, num_blocks[2], stride=2) 229 | self.layer4 = self._make_layer(block, 512, 512, num_blocks[3], stride=2) 230 | self.linear = nn.Linear(512 * block.expansion, num_classes) 231 | 232 | def _make_layer(self, block, hidden, planes, num_blocks, stride): 233 | strides = [stride] + [1] * (num_blocks - 
1) 234 | layers = [] 235 | for stride in strides: 236 | layers.append(block(self.in_planes, hidden, planes, stride)) 237 | self.in_planes = planes * block.expansion 238 | return nn.Sequential(*layers) 239 | 240 | def forward(self, x): 241 | out = F.relu(self.bn1(self.conv1(x))) 242 | out = self.layer1(out) 243 | out = self.layer2(out) 244 | out = self.layer3(out) 245 | out = self.layer4(out) 246 | out = F.avg_pool2d(out, 4) 247 | out = out.view(out.size(0), -1) 248 | out = self.linear(out) 249 | return out 250 | 251 | 252 | class ImplicitResNet(nn.Module): 253 | def __init__(self, block, num_blocks, num_classes=10): 254 | super(ImplicitResNet, self).__init__() 255 | self.in_planes = 64 256 | 257 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 258 | self.bn1 = nn.BatchNorm2d(64) 259 | self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1) 260 | self.layer2 = self._make_layer(block, 128, 128, num_blocks[1], stride=2) 261 | self.layer3 = self._make_layer(block, 256, 256, num_blocks[2], stride=2) 262 | self.layer4 = self._make_layer(block, 512, 512, num_blocks[3], stride=2) 263 | self.linear = nn.Linear(512 * block.expansion, num_classes) 264 | 265 | def _make_layer(self, block, hidden, planes, num_blocks, stride): 266 | strides = [stride] + [1] * (num_blocks - 1) 267 | layers = [] 268 | for stride in strides: 269 | layers.append(block(self.in_planes, hidden, planes, stride)) 270 | self.in_planes = planes * block.expansion 271 | return nn.Sequential(*layers) 272 | 273 | def forward(self, x): 274 | out = F.relu(self.bn1(self.conv1(x))) 275 | out = self.layer1(out) 276 | out = self.layer2(out) 277 | out = self.layer3(out) 278 | out = self.layer4(out) 279 | out = F.avg_pool2d(out, 4) 280 | out = out.view(out.size(0), -1) 281 | out = self.linear(out) 282 | return out 283 | 284 | 285 | def ResNet18(num_classes=10): 286 | return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=num_classes) 287 | 288 | def 
ImplicitResNet18(num_classes=10): 289 | return ImplicitResNet(BasicImplicitBlock, [1, 1, 1, 1], num_classes=num_classes) 290 | 291 | def ResNet34(): 292 | return ResNet(BasicBlock, [3, 4, 6, 3]) 293 | 294 | 295 | def ResNet50(): 296 | return ResNet(Bottleneck, [3, 4, 6, 3]) 297 | 298 | 299 | def ResNet101(): 300 | return ResNet(Bottleneck, [3, 4, 23, 3]) 301 | 302 | 303 | def ResNet152(): 304 | return ResNet(Bottleneck, [3, 8, 36, 3]) 305 | 306 | 307 | # setup data loader 308 | transform_train = [ 309 | transforms.RandomCrop(32, padding=4), 310 | transforms.RandomHorizontalFlip(), 311 | transforms.ToTensor(), 312 | ] 313 | transform_test = [ 314 | transforms.ToTensor(), 315 | ] 316 | 317 | if args.data == 'cifar10': 318 | trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transforms.Compose(transform_train)) 319 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs) 320 | testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transforms.Compose(transform_test)) 321 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs) 322 | args.num_classes = 10 323 | elif args.data == 'cifar100': 324 | trainset = torchvision.datasets.CIFAR100(root='data', train=True, download=True, transform=transforms.Compose(transform_train)) 325 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs) 326 | testset = torchvision.datasets.CIFAR100(root='data', train=False, download=True, transform=transforms.Compose(transform_test)) 327 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs) 328 | args.num_classes = 100 329 | elif args.data == 'mnist': 330 | transform_mnist = [transforms.Pad(2, 0),] 331 | transform_mnist2 = [lambda x: x.repeat((3, 1, 1))] 332 | trainset = torchvision.datasets.MNIST(root='data', 
train=True, download=True, transform=transforms.Compose(transform_mnist + transform_train + transform_mnist2)) 333 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs) 334 | testset = torchvision.datasets.MNIST(root='data', train=False, download=True, transform=transforms.Compose(transform_mnist + transform_test + transform_mnist2)) 335 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs) 336 | args.num_classes = 10 337 | 338 | batch_time = utils.RunningAverageMeter(0.97) 339 | loss_meter = utils.RunningAverageMeter(0.97) 340 | 341 | 342 | def update_lipschitz(model): 343 | with torch.no_grad(): 344 | for m in model.modules(): 345 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear): 346 | m.compute_weight(update=True) 347 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 348 | m.compute_weight(update=True) 349 | 350 | 351 | def train(args, model, device, train_loader, optimizer, epoch, ema): 352 | model.train() 353 | end = time.time() 354 | for batch_idx, (data, target) in enumerate(train_loader): 355 | data, target = data.to(device), target.to(device) 356 | output = model(data) 357 | loss = F.cross_entropy(output, target, size_average=False) 358 | loss_meter.update(loss.item()) 359 | 360 | loss.backward() 361 | optimizer.step() 362 | optimizer.zero_grad() 363 | update_lipschitz(model) 364 | ema.apply() 365 | 366 | batch_time.update(time.time() - end) 367 | end = time.time() 368 | 369 | # print progress 370 | if batch_idx % args.log_interval == 0: 371 | logger.info('Train Epoch: {} [{}/{} ({:.0f}%)] | Time {batch_time.val:.3f} | Loss: {loss_meter.val:.6f}'.format( 372 | epoch, batch_idx * len(data), len(train_loader.dataset), 373 | 100. 
* batch_idx / len(train_loader), batch_time=batch_time, loss_meter=loss_meter)) 374 | 375 | 376 | def eval_train(model, device, train_loader, ema): 377 | ema.swap() 378 | update_lipschitz(model) 379 | model.eval() 380 | train_loss = 0 381 | correct = 0 382 | with torch.no_grad(): 383 | for data, target in train_loader: 384 | data, target = data.to(device), target.to(device) 385 | output = model(data) 386 | train_loss += F.cross_entropy(output, target, size_average=False).item() 387 | pred = output.max(1, keepdim=True)[1] 388 | correct += pred.eq(target.view_as(pred)).sum().item() 389 | train_loss /= len(train_loader.dataset) 390 | logger.info('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 391 | train_loss, correct, len(train_loader.dataset), 392 | 100. * correct / len(train_loader.dataset))) 393 | training_accuracy = correct / len(train_loader.dataset) 394 | ema.swap() 395 | return train_loss, training_accuracy 396 | 397 | 398 | def eval_test(model, device, test_loader, ema): 399 | ema.swap() 400 | model.eval() 401 | test_loss = 0 402 | correct = 0 403 | with torch.no_grad(): 404 | for data, target in test_loader: 405 | data, target = data.to(device), target.to(device) 406 | output = model(data) 407 | test_loss += F.cross_entropy(output, target, size_average=False).item() 408 | pred = output.max(1, keepdim=True)[1] 409 | correct += pred.eq(target.view_as(pred)).sum().item() 410 | test_loss /= len(test_loader.dataset) 411 | logger.info('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 412 | test_loss, correct, len(test_loader.dataset), 413 | 100. 
* correct / len(test_loader.dataset))) 414 | test_accuracy = correct / len(test_loader.dataset) 415 | ema.swap() 416 | return test_loss, test_accuracy 417 | 418 | 419 | def adjust_learning_rate(optimizer, epoch): 420 | """decrease the learning rate""" 421 | lr = args.lr 422 | if epoch >= 75: 423 | lr = args.lr * 0.1 424 | if epoch >= 90: 425 | lr = args.lr * 0.01 426 | if epoch >= 100: 427 | lr = args.lr * 0.001 428 | for param_group in optimizer.param_groups: 429 | param_group['lr'] = lr 430 | 431 | 432 | def main(): 433 | # model = torch.nn.DataParallel(ResNet18(num_classes=args.num_classes).to(device)) 434 | model = torch.nn.DataParallel(ImplicitResNet18(num_classes=args.num_classes).to(device)) 435 | with torch.no_grad(): 436 | x, _ = next(iter(train_loader)) 437 | x = x.to(device) 438 | model(x) 439 | ema = utils.ExponentialMovingAverage(model) 440 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay) 441 | # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 442 | logger.info(model) 443 | logger.info('EMA: {}'.format(ema)) 444 | logger.info(optimizer) 445 | 446 | for epoch in range(1, args.epochs + 1): 447 | # adjust learning rate for SGD 448 | adjust_learning_rate(optimizer, epoch) 449 | 450 | # adversarial training 451 | train(args, model, device, train_loader, optimizer, epoch, ema) 452 | 453 | # evaluation on natural examples 454 | logger.info('================================================================') 455 | eval_train(model, device, train_loader, ema) 456 | eval_test(model, device, test_loader, ema) 457 | logger.info('================================================================') 458 | 459 | # save checkpoint 460 | if epoch == args.epochs: 461 | torch.save(model.state_dict(), 462 | os.path.join(model_dir, 'model-wideres-epoch{}.pt'.format(epoch))) 463 | torch.save(optimizer.state_dict(), 464 | os.path.join(model_dir, 
'opt-wideres-checkpoint_epoch{}.tar'.format(epoch))) 465 | 466 | 467 | if __name__ == '__main__': 468 | main() -------------------------------------------------------------------------------- /train_toy.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | 5 | import argparse 6 | import os 7 | import time 8 | import math 9 | import numpy as np 10 | 11 | import torch 12 | 13 | import lib.optimizers as optim 14 | import lib.layers.base as base_layers 15 | import lib.layers as layers 16 | import lib.toy_data as toy_data 17 | import lib.utils as utils 18 | from lib.visualize_flow import visualize_transform 19 | 20 | 21 | ACTIVATION_FNS = { 22 | 'identity': base_layers.Identity, 23 | 'relu': torch.nn.ReLU, 24 | 'tanh': torch.nn.Tanh, 25 | 'elu': torch.nn.ELU, 26 | 'selu': torch.nn.SELU, 27 | 'fullsort': base_layers.FullSort, 28 | 'maxmin': base_layers.MaxMin, 29 | 'swish': base_layers.Swish, 30 | 'lcube': base_layers.LipschitzCube, 31 | 'sin': base_layers.Sin, 32 | } 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], 37 | type=str, default='pinwheel' 38 | ) 39 | parser.add_argument('--arch', choices=['iresnet', 'realnvp', 'implicit'], default='implicit') 40 | parser.add_argument('--coeff', type=float, default=0.9) 41 | parser.add_argument('--vnorms', type=str, default='222222') 42 | parser.add_argument('--n-lipschitz-iters', type=int, default=5) 43 | parser.add_argument('--atol', type=float, default=None) 44 | parser.add_argument('--rtol', type=float, default=None) 45 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False) 46 | parser.add_argument('--mixed', type=eval, choices=[True, False], default=True) 47 | 48 | parser.add_argument('--dims', type=str, default='128-128-128-128') 49 | 
parser.add_argument('--act', type=str, choices=ACTIVATION_FNS.keys(), default='sin')
parser.add_argument('--nblocks', type=int, default=100)
# Boolean switches: ``eval`` turns the literal command-line text True/False into a bool.
# NOTE(review): ``eval`` executes arbitrary command-line text; acceptable only for trusted local runs.
for _switch in ('--brute-force', '--actnorm', '--batchnorm', '--exact-trace'):
    parser.add_argument(_switch, type=eval, choices=[True, False], default=False)
parser.add_argument('--n-power-series', type=int, default=None)
parser.add_argument('--n-samples', type=int, default=1)
parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric')

# Optimization hyper-parameters, registered as (flag, type, default).
for _flag, _kind, _default in (
    ('--niters', int, 50000),
    ('--batch_size', int, 1000),
    ('--test_batch_size', int, 10000),
    ('--lr', float, 1e-3),
    ('--weight-decay', float, 1e-5),
    ('--annealing-iters', int, 0),
):
    parser.add_argument(_flag, type=_kind, default=_default)

# Checkpoint resumption.
parser.add_argument('--resume', type=str, default=None)
parser.add_argument('--begin-epoch', type=int, default=0)

# Output directory, then integer options for logging cadence, GPU index and seed.
parser.add_argument('--save', type=str, default='experiments/iresnet_toy')
for _flag, _default in (
    ('--viz_freq', 1000),
    ('--val_freq', 1000),
    ('--log_freq', 100),
    ('--gpu', 0),
    ('--seed', 0),
):
    parser.add_argument(_flag, type=int, default=_default)
del _switch, _flag, _kind, _default
args = parser.parse_args()

# Create the experiment directory and a logger that writes its 'logs' file there.
utils.makedirs(args.save)
logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
logger.info(args)

device = torch.device(f'cuda:{args.gpu}') if torch.cuda.is_available() else torch.device('cpu')

np.random.seed(args.seed)
torch.manual_seed(args.seed) 86 | if device.type == 'cuda': 87 | torch.cuda.manual_seed(args.seed) 88 | 89 | 90 | def count_parameters(model): 91 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 92 | 93 | 94 | def standard_normal_sample(size): 95 | return torch.randn(size) 96 | 97 | 98 | def standard_normal_logprob(z): 99 | logZ = -0.5 * np.log(2 * np.pi) 100 | return logZ - z.pow(2) / 2 101 | 102 | 103 | def compute_loss(args, model, batch_size=None, beta=1.): 104 | if batch_size is None: batch_size = args.batch_size 105 | 106 | # load data 107 | x = toy_data.inf_train_gen(args.data, batch_size=batch_size) 108 | x = torch.from_numpy(x).type(torch.float32).to(device) 109 | zero = torch.zeros(x.shape[0], 1).to(x) 110 | 111 | # transform to z 112 | z, delta_logp = model(x, zero) 113 | 114 | # compute log p(z) 115 | logpz = standard_normal_logprob(z).sum(1, keepdim=True) 116 | 117 | logpx = logpz - beta * delta_logp 118 | loss = -torch.mean(logpx) 119 | return loss, torch.mean(logpz), torch.mean(-delta_logp) 120 | 121 | 122 | def parse_vnorms(): 123 | ps = [] 124 | for p in args.vnorms: 125 | if p == 'f': 126 | ps.append(float('inf')) 127 | else: 128 | ps.append(float(p)) 129 | return ps[:-1], ps[1:] 130 | 131 | 132 | def compute_p_grads(model): 133 | scales = 0. 
134 | nlayers = 0 135 | for m in model.modules(): 136 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 137 | scales = scales + m.compute_one_iter() 138 | nlayers += 1 139 | scales.mul(1 / nlayers).mul(0.01).backward() 140 | for m in model.modules(): 141 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 142 | if m.domain.grad is not None and torch.isnan(m.domain.grad): 143 | m.domain.grad = None 144 | 145 | 146 | def build_nnet(dims, activation_fn=torch.nn.ReLU): 147 | nnet = [] 148 | domains, codomains = parse_vnorms() 149 | if args.learn_p: 150 | if args.mixed: 151 | domains = [torch.nn.Parameter(torch.tensor(0.)) for _ in domains] 152 | else: 153 | domains = [torch.nn.Parameter(torch.tensor(0.))] * len(domains) 154 | codomains = domains[1:] + [domains[0]] 155 | for i, (in_dim, out_dim, domain, codomain) in enumerate(zip(dims[:-1], dims[1:], domains, codomains)): 156 | if i > 0: 157 | nnet.append(activation_fn()) 158 | nnet.append( 159 | base_layers.get_linear( 160 | in_dim, 161 | out_dim, 162 | coeff=args.coeff, 163 | n_iterations=args.n_lipschitz_iters, 164 | atol=args.atol, 165 | rtol=args.rtol, 166 | domain=domain, 167 | codomain=codomain, 168 | zero_init=(out_dim == 2), 169 | ) 170 | ) 171 | return torch.nn.Sequential(*nnet) 172 | 173 | 174 | def update_lipschitz(model, n_iterations): 175 | for m in model.modules(): 176 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear): 177 | m.compute_weight(update=True, n_iterations=n_iterations) 178 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 179 | m.compute_weight(update=True, n_iterations=n_iterations) 180 | 181 | 182 | def get_ords(model): 183 | ords = [] 184 | for m in model.modules(): 185 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 186 | domain, codomain = 
m.compute_domain_codomain() 187 | if torch.is_tensor(domain): 188 | domain = domain.item() 189 | if torch.is_tensor(codomain): 190 | codomain = codomain.item() 191 | ords.append(domain) 192 | ords.append(codomain) 193 | return ords 194 | 195 | 196 | def pretty_repr(a): 197 | return '[[' + ','.join(list(map(lambda i: f'{i:.2f}', a))) + ']]' 198 | 199 | 200 | if __name__ == '__main__': 201 | 202 | activation_fn = ACTIVATION_FNS[args.act] 203 | 204 | if args.arch == 'iresnet': 205 | dims = [2] + list(map(int, args.dims.split('-'))) + [2] 206 | blocks = [] 207 | if args.actnorm: blocks.append(layers.ActNorm1d(2)) 208 | for _ in range(args.nblocks): 209 | blocks.append( 210 | layers.iResBlock( 211 | build_nnet(dims, activation_fn), 212 | n_dist=args.n_dist, 213 | n_power_series=args.n_power_series, 214 | exact_trace=args.exact_trace, 215 | brute_force=args.brute_force, 216 | n_samples=args.n_samples, 217 | neumann_grad=False, 218 | grad_in_forward=False, 219 | ) 220 | ) 221 | if args.actnorm: blocks.append(layers.ActNorm1d(2)) 222 | if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2)) 223 | model = layers.SequentialFlow(blocks).to(device) 224 | elif args.arch == 'implicit': 225 | dims = [2] + list(map(int, args.dims.split('-'))) + [2] 226 | blocks = [] 227 | if args.actnorm: blocks.append(layers.ActNorm1d(2)) 228 | for _ in range(args.nblocks): 229 | blocks.append( 230 | layers.imBlock( 231 | build_nnet(dims, activation_fn), 232 | build_nnet(dims, activation_fn), 233 | n_dist=args.n_dist, 234 | n_power_series=args.n_power_series, 235 | exact_trace=args.exact_trace, 236 | brute_force=args.brute_force, 237 | n_samples=args.n_samples, 238 | neumann_grad=False, 239 | grad_in_forward=False, # toy data needn't save memory 240 | ) 241 | ) 242 | model = torch.nn.DataParallel(layers.SequentialFlow(blocks).to(device)) 243 | elif args.arch == 'realnvp': 244 | blocks = [] 245 | for _ in range(args.nblocks): 246 | blocks.append(layers.CouplingBlock(2, swap=False)) 247 | 
blocks.append(layers.CouplingBlock(2, swap=True)) 248 | if args.actnorm: blocks.append(layers.ActNorm1d(2)) 249 | if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2)) 250 | model = layers.SequentialFlow(blocks).to(device) 251 | 252 | logger.info(model) 253 | logger.info("Number of trainable parameters: {}".format(count_parameters(model))) 254 | 255 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 256 | 257 | time_meter = utils.RunningAverageMeter(0.93) 258 | loss_meter = utils.RunningAverageMeter(0.93) 259 | logpz_meter = utils.RunningAverageMeter(0.93) 260 | delta_logp_meter = utils.RunningAverageMeter(0.93) 261 | 262 | end = time.time() 263 | 264 | if (args.resume is not None): 265 | logger.info('Resuming model from {}'.format(args.resume)) 266 | with torch.no_grad(): 267 | x = toy_data.inf_train_gen(args.data, batch_size=args.batch_size) 268 | x = torch.from_numpy(x).type(torch.float32).to(device) 269 | model(x, restore=True) 270 | checkpt = torch.load(args.resume) 271 | sd = {k: v for k, v in checkpt['state_dict'].items() if 'last_n_samples' not in k} 272 | state = model.state_dict() 273 | state.update(sd) 274 | model.load_state_dict(state, strict=True) 275 | del checkpt 276 | del state 277 | else: 278 | with torch.no_grad(): 279 | x = toy_data.inf_train_gen(args.data, batch_size=args.batch_size) 280 | x = torch.from_numpy(x).type(torch.float32).to(device) 281 | model(x, restore=True) 282 | 283 | best_loss = float('inf') 284 | model.train() 285 | for itr in range(1, args.niters + 1): 286 | optimizer.zero_grad() 287 | 288 | beta = min(1, itr / args.annealing_iters) if args.annealing_iters > 0 else 1. 
289 | loss, logpz, delta_logp = compute_loss(args, model, beta=beta) 290 | loss_meter.update(loss.item()) 291 | logpz_meter.update(logpz.item()) 292 | delta_logp_meter.update(delta_logp.item()) 293 | loss.backward() 294 | if args.learn_p and itr > args.annealing_iters: compute_p_grads(model) 295 | optimizer.step() 296 | update_lipschitz(model, args.n_lipschitz_iters) 297 | 298 | time_meter.update(time.time() - end) 299 | 300 | if itr % args.log_freq == 0: 301 | logger.info( 302 | 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f})' 303 | ' | Logp(z) {:.6f}({:.6f}) | DeltaLogp {:.6f}({:.6f})'.format( 304 | itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, logpz_meter.val, logpz_meter.avg, 305 | delta_logp_meter.val, delta_logp_meter.avg 306 | ) 307 | ) 308 | 309 | if itr % args.val_freq == 0 or itr == args.niters: 310 | update_lipschitz(model, 200) 311 | with torch.no_grad(): 312 | model.eval() 313 | test_loss, test_logpz, test_delta_logp = compute_loss(args, model, batch_size=args.test_batch_size) 314 | log_message = ( 315 | '[TEST] Iter {:04d} | Test Loss {:.6f} ' 316 | '| Test Logp(z) {:.6f} | Test DeltaLogp {:.6f}'.format( 317 | itr, test_loss.item(), test_logpz.item(), test_delta_logp.item() 318 | ) 319 | ) 320 | logger.info(log_message) 321 | 322 | logger.info('Ords: {}'.format(pretty_repr(get_ords(model)))) 323 | 324 | if test_loss.item() < best_loss: 325 | best_loss = test_loss.item() 326 | utils.makedirs(args.save) 327 | torch.save({ 328 | 'args': args, 329 | 'state_dict': model.state_dict(), 330 | }, os.path.join(args.save, 'checkpt.pth')) 331 | model.train() 332 | 333 | if itr == 1 or itr % args.viz_freq == 0: 334 | with torch.no_grad(): 335 | model.eval() 336 | p_samples = toy_data.inf_train_gen(args.data, batch_size=20000) 337 | 338 | sample_fn, density_fn = model.module.inverse, model.forward 339 | 340 | plt.figure(figsize=(9, 3)) 341 | visualize_transform( 342 | p_samples, torch.randn, standard_normal_logprob, 
transform=sample_fn, inverse_transform=density_fn, 343 | samples=True, npts=400, device=device 344 | ) 345 | fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) 346 | utils.makedirs(os.path.dirname(fig_filename)) 347 | plt.savefig(fig_filename) 348 | plt.close() 349 | model.train() 350 | 351 | end = time.time() 352 | 353 | logger.info('Training has finished.') 354 | --------------------------------------------------------------------------------