├── .gitignore ├── LICENSE ├── README.md ├── assets ├── celebahq_resflow.jpg └── flow_comparison.jpg ├── preprocessing ├── convert_to_pth.py ├── create_imagenet_benchmark_datasets.py └── extract_celeba_from_tfrecords.py ├── qualitative_samples.py ├── resflows ├── datasets.py ├── layers │ ├── __init__.py │ ├── act_norm.py │ ├── base │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── lipschitz.py │ │ ├── mixed_lipschitz.py │ │ └── utils.py │ ├── container.py │ ├── coupling.py │ ├── elemwise.py │ ├── glow.py │ ├── iresblock.py │ ├── mask_utils.py │ ├── normalization.py │ └── squeeze.py ├── lr_scheduler.py ├── optimizers.py ├── resflow.py ├── toy_data.py ├── utils.py └── visualize_flow.py ├── setup.py ├── train_img.py └── train_toy.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *__pycache__* 3 | data/* 4 | pretrained_models 5 | *egg-info 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ricky Tian Qi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Residual Flows for Invertible Generative Modeling [[arxiv](https://arxiv.org/abs/1906.02735)] 2 | 3 |
4 | 5 |
6 | 7 | Building on the use of [Invertible Residual Networks](https://arxiv.org/abs/1811.00995) in generative modeling, we propose: 8 | + Unbiased estimation of the log-density of samples. 9 | + Memory-efficient reformulation of the gradients. 10 | + LipSwish activation function. 11 | 12 | As a result, Residual Flows scale to much larger networks and datasets. 13 | 14 |
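The first item above can be illustrated with a small NumPy sketch of the general idea. This is not the repository's PyTorch code. The log-determinant of a residual block's Jacobian is expanded as the power series `log det(I + J) = sum_{k>=1} (-1)^(k+1) tr(J^k) / k`. Each trace is replaced by a Hutchinson estimate `v^T J^k v`, and the infinite sum is cut off at a random length, with each term reweighted by the probability of reaching it ("Russian roulette"). That reweighting keeps the estimate unbiased. The function name `roulette_logdet`, the explicit matrix `J`, and the defaults `n_exact_terms=2` and `p=0.5` are illustrative assumptions; they only loosely mirror the `--n-exact-terms` and `--n-dist geometric` options seen in the scripts. The series is assumed to converge, for example because the spectral norm of `J` is below 1, which the Lipschitz constraint on each residual block is meant to ensure.

```python
import numpy as np


def roulette_logdet(J, rng, n_exact_terms=2, p=0.5):
    """One unbiased sample of log det(I + J) via a randomly truncated power series.

    Assumes the series sum_k (-1)^(k+1) tr(J^k) / k converges, e.g. because the
    spectral norm of J is below 1.
    """
    d = J.shape[0]
    # Russian roulette: always evaluate the first n_exact_terms terms, then add
    # M ~ Geometric(p) more, with M in {1, 2, ...} and P(M >= m) = (1 - p)^(m - 1).
    n_terms = n_exact_terms + rng.geometric(p)
    # Hutchinson: for v ~ N(0, I), E[v^T A v] = tr(A).
    v = rng.standard_normal(d)
    w = v
    estimate = 0.0
    for k in range(1, n_terms + 1):
        w = J @ w  # w now equals J^k v
        trace_k = v @ w  # unbiased estimate of tr(J^k)
        # Divide by P(n_terms >= k) so that the random truncation adds no bias.
        if k <= n_exact_terms:
            prob_reached = 1.0
        else:
            prob_reached = (1.0 - p) ** (k - n_exact_terms - 1)
        estimate += (-1) ** (k + 1) / k * trace_k / prob_reached
    return estimate


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    J = np.array([[0.2, 0.1], [0.0, -0.3]])
    samples = [roulette_logdet(J, rng) for _ in range(20000)]
    # The sample mean should approach the exact value log(1.2 * 0.7).
    print(np.mean(samples), np.linalg.slogdet(np.eye(2) + J)[1])
```

Each sample needs only a finite, random number of matrix-vector products, and averaging many samples recovers the exact value. In this sketch, `p` trades expected cost against variance: a smaller `p` evaluates more terms on average but needs smaller reweighting factors.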
15 | 16 |
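For the LipSwish item, `SwishFn` in `resflows/layers/base/activations.py` divides Swish, `x * sigmoid(beta * x)`, by 1.1. The reason can be checked numerically. Substituting `t = beta * x`, the derivative of Swish becomes `sigmoid(t) + t * sigmoid(t) * (1 - sigmoid(t))`, which does not depend on `beta` and peaks at roughly 1.0998. Dividing by 1.1 therefore keeps the slope, and hence the Lipschitz constant, below 1 for any `beta > 0`. The NumPy helpers `lipswish` and `lipswish_grad` are hypothetical names used only for this illustration; they are not the repository's module.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def lipswish(x, beta=1.0):
    # Swish x * sigmoid(beta * x), rescaled by 1 / 1.1 as in SwishFn.forward.
    return x * sigmoid(beta * x) / 1.1


def lipswish_grad(x, beta=1.0):
    # d/dx [x * s] = s + beta * x * s * (1 - s), with s = sigmoid(beta * x).
    s = sigmoid(beta * x)
    return (s + beta * x * s * (1.0 - s)) / 1.1


if __name__ == "__main__":
    xs = np.linspace(-10.0, 10.0, 200001)
    for beta in (0.5, 1.0, 2.0):
        # The peak slope stays just below 1 for every beta.
        print(beta, np.abs(lipswish_grad(xs, beta)).max())
```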
17 | 18 | ## Requirements 19 | 20 | - PyTorch 1.0+ 21 | - Python 3.6+ 22 | 23 | ## Package installation (optional) 24 | 25 | If you only want to use the `resflows` package, you can install it with pip: 26 | ``` 27 | pip install git+https://github.com/rtqichen/residual-flows 28 | ``` 29 | 30 | ## Preprocessing 31 | ImageNet: 32 | 1. Follow the instructions in the docstring of `preprocessing/create_imagenet_benchmark_datasets.py`, then run the script. 33 | 2. Convert each .npy file to .pth with `python preprocessing/convert_to_pth.py <file>.npy` (the .pth file is written next to the input). 34 | 3. Place the results in `data/imagenet32` and `data/imagenet64` (the loaders in `resflows/datasets.py` expect `train_32x32.pth`/`valid_32x32.pth` and `train_64x64.pth`/`valid_64x64.pth`). 35 | 36 | CelebAHQ 64x64 5bit: 37 | 38 | 1. Download from https://github.com/aravindsrinivas/flowpp/tree/master/flows_celeba. 39 | 2. Convert .npy files to .pth with `python preprocessing/convert_to_pth.py <file>.npy`. 40 | 3. Place the result in `data/celebahq64_5bit` (expected filename: `celeba_full_64x64_5bit.pth`). 41 | 42 | CelebAHQ 256x256: 43 | ``` 44 | # Download Glow's preprocessed dataset. 45 | wget https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar 46 | mkdir -p data/celebahq && tar -C data/celebahq -xvf celeba-tfr.tar 47 | python preprocessing/extract_celeba_from_tfrecords.py 48 | ``` 49 | 50 | ## Density Estimation Experiments 51 | 52 | ***NOTE***: By default, O(1)-memory gradients are enabled. In this mode, the bits/dim logged during training is not an actual estimate of bits/dim but the surrogate scalar used to generate the unbiased gradients. To check the actual training bits/dim (if you have sufficient GPU memory), set `--neumann-grad=False`. Note, however, that the memory cost can vary stochastically during training when this flag is `False`.
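When reading the logged numbers, it may help to recall how an actual bits/dim value is obtained from a model's log-density. The sketch below assumes 8-bit images dequantized into [0, 1) the way `add_noise` in `qualitative_samples.py` does it (pixel level plus uniform noise, divided by 256). The function name and arguments are illustrative and are not the training script's API.

```python
import math


def bits_per_dim(logp_y, num_dims, n_bins=256):
    """Convert a continuous log-density (in nats) into bits per dimension.

    Assumes y = (x + u) / n_bins, with integer pixel levels x in [0, n_bins)
    and uniform noise u in [0, 1), so y lies in [0, 1)^num_dims.
    """
    # Undo the 1 / n_bins rescaling: log p(x + u) = log p(y) - D * log(n_bins).
    logp_unscaled = logp_y - num_dims * math.log(n_bins)
    # Negative log-likelihood per dimension, converted from nats to bits.
    return -logp_unscaled / (num_dims * math.log(2))


# A uniform density on [0, 1)^D has log p(y) = 0, which gives 8 bits/dim
# (up to floating-point rounding) for 256 levels.
print(bits_per_dim(0.0, 3 * 32 * 32))
```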
53 | 54 | MNIST: 55 | ``` 56 | python train_img.py --data mnist --imagesize 28 --actnorm True --wd 0 --save experiments/mnist 57 | ``` 58 | 59 | CIFAR10: 60 | ``` 61 | python train_img.py --data cifar10 --actnorm True --save experiments/cifar10 62 | ``` 63 | 64 | ImageNet 32x32: 65 | ``` 66 | python train_img.py --data imagenet32 --actnorm True --nblocks 32-32-32 --save experiments/imagenet32 67 | ``` 68 | 69 | ImageNet 64x64: 70 | ``` 71 | python train_img.py --data imagenet64 --imagesize 64 --actnorm True --nblocks 32-32-32 --factor-out True --squeeze-first True --save experiments/imagenet64 72 | ``` 73 | 74 | CelebAHQ 256x256: 75 | ``` 76 | python train_img.py --data celebahq --imagesize 256 --nbits 5 --actnorm True --act elu --batchsize 8 --update-freq 5 --n-exact-terms 8 --fc-end False --factor-out True --squeeze-first True --nblocks 16-16-16-16-16-16 --save experiments/celebahq256 77 | ``` 78 | 79 | ## Pretrained Models 80 | 81 | Model checkpoints can be downloaded from [releases](https://github.com/rtqichen/residual-flows/releases/latest). 82 | 83 | Use the argument `--resume [checkpt.pth]` to evaluate or sample from the model. 84 | 85 | Each checkpoint contains two sets of parameters, one from training and one containing the exponential moving average (EMA) accumulated over the course of training. Scripts will automatically use the EMA parameters for evaluation and sampling. 86 | 87 | ## BibTeX 88 | ``` 89 | @inproceedings{chen2019residualflows, 90 | title={Residual Flows for Invertible Generative Modeling}, 91 | author={Chen, Ricky T. Q. 
and Behrmann, Jens and Duvenaud, David and Jacobsen, J{\"{o}}rn{-}Henrik}, 92 | booktitle = {Advances in Neural Information Processing Systems}, 93 | year={2019} 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /assets/celebahq_resflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtqichen/residual-flows/8170138c850a3574319491d97093bc860ce4922d/assets/celebahq_resflow.jpg -------------------------------------------------------------------------------- /assets/flow_comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtqichen/residual-flows/8170138c850a3574319491d97093bc860ce4922d/assets/flow_comparison.jpg -------------------------------------------------------------------------------- /preprocessing/convert_to_pth.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | import numpy as np 4 | import torch 5 | 6 | img = torch.tensor(np.load(sys.argv[1])) 7 | img = img.permute(0, 3, 1, 2) 8 | torch.save(img, re.sub(r'\.npy$', '.pth', sys.argv[1])) 9 | -------------------------------------------------------------------------------- /preprocessing/create_imagenet_benchmark_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run the following commands in ~ before running this file 3 | wget http://image-net.org/small/train_64x64.tar 4 | wget http://image-net.org/small/valid_64x64.tar 5 | tar -xvf train_64x64.tar 6 | tar -xvf valid_64x64.tar 7 | wget http://image-net.org/small/train_32x32.tar 8 | wget http://image-net.org/small/valid_32x32.tar 9 | tar -xvf train_32x32.tar 10 | tar -xvf valid_32x32.tar 11 | """ 12 | 13 | import numpy as np 14 | from PIL import Image 15 | import os 16 | from os import listdir 17 | from os.path import isfile, join 18 | from tqdm
import tqdm 19 | 20 | 21 | def convert_path_to_npy(*, path='train_64x64', outfile='train_64x64.npy'): 22 | assert isinstance(path, str), "Expected a string input for the path" 23 | assert os.path.exists(path), "Input path doesn't exist" 24 | files = [f for f in listdir(path) if isfile(join(path, f))] 25 | print('Number of valid images is:', len(files)) 26 | imgs = [] 27 | for i in tqdm(range(len(files))): 28 | img = np.asarray(Image.open(join(path, files[i])).convert('RGB')) 29 | img = img.astype('uint8') 30 | assert np.max(img) <= 255 31 | assert np.min(img) >= 0 32 | assert img.dtype == 'uint8' 33 | assert isinstance(img, np.ndarray) 34 | imgs.append(img) 35 | resolution_x, resolution_y = img.shape[0], img.shape[1] 36 | imgs = np.asarray(imgs).astype('uint8') 37 | assert imgs.shape[1:] == (resolution_x, resolution_y, 3) 38 | assert np.max(imgs) <= 255 39 | assert np.min(imgs) >= 0 40 | print('Total number of images is:', imgs.shape[0]) 41 | print('All assertions done, dumping into npy file') 42 | np.save(outfile, imgs) 43 | 44 | 45 | if __name__ == '__main__': 46 | convert_path_to_npy(path='train_64x64', outfile='train_64x64.npy') 47 | convert_path_to_npy(path='valid_64x64', outfile='valid_64x64.npy') 48 | convert_path_to_npy(path='train_32x32', outfile='train_32x32.npy') 49 | convert_path_to_npy(path='valid_32x32', outfile='valid_32x32.npy') 50 | -------------------------------------------------------------------------------- /preprocessing/extract_celeba_from_tfrecords.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | 5 | sess = tf.InteractiveSession() 6 | 7 | train_imgs = [] 8 | 9 | print('Reading from training set...', flush=True) 10 | for i in range(120): 11 | tfr = 'data/celebahq/celeba-tfr/train/train-r08-s-{:04d}-of-0120.tfrecords'.format(i) 12 | print(tfr, flush=True) 13 | 14 | record_iterator = tf.python_io.tf_record_iterator(tfr) 15 | 16 | for string_record in
record_iterator: 17 | example = tf.train.Example() 18 | example.ParseFromString(string_record) 19 | 20 | image_bytes = example.features.feature['data'].bytes_list.value[0] 21 | 22 | img = tf.decode_raw(image_bytes, tf.uint8) 23 | img = tf.reshape(img, [256, 256, 3]) 24 | img = img.eval() 25 | 26 | train_imgs.append(img) 27 | 28 | train_imgs = np.stack(train_imgs) 29 | train_imgs = torch.tensor(train_imgs).permute(0, 3, 1, 2) 30 | torch.save(train_imgs, 'data/celebahq/celeba256_train.pth') 31 | 32 | validation_imgs = [] 33 | for i in range(40): 34 | tfr = 'data/celebahq/celeba-tfr/validation/validation-r08-s-{:04d}-of-0040.tfrecords'.format(i) 35 | print(tfr, flush=True) 36 | 37 | record_iterator = tf.python_io.tf_record_iterator(tfr) 38 | 39 | for string_record in record_iterator: 40 | example = tf.train.Example() 41 | example.ParseFromString(string_record) 42 | 43 | image_bytes = example.features.feature['data'].bytes_list.value[0] 44 | 45 | img = tf.decode_raw(image_bytes, tf.uint8) 46 | img = tf.reshape(img, [256, 256, 3]) 47 | img = img.eval() 48 | 49 | validation_imgs.append(img) 50 | 51 | validation_imgs = np.stack(validation_imgs) 52 | validation_imgs = torch.tensor(validation_imgs).permute(0, 3, 1, 2) 53 | torch.save(validation_imgs, 'data/celebahq/celeba256_validation.pth') 54 | -------------------------------------------------------------------------------- /qualitative_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | from tqdm import tqdm 4 | 5 | import torch 6 | import torchvision.transforms as transforms 7 | from torchvision.utils import save_image 8 | import torchvision.datasets as vdsets 9 | 10 | from resflows.resflow import ACT_FNS, ResidualFlow 11 | import resflows.datasets as datasets 12 | import resflows.utils as utils 13 | import resflows.layers as layers 14 | import resflows.layers.base as base_layers 15 | 16 | # Arguments 17 | parser = argparse.ArgumentParser()
18 | parser.add_argument( 19 | '--data', type=str, default='cifar10', choices=[ 20 | 'mnist', 21 | 'cifar10', 22 | 'celeba', 23 | 'celebahq', 24 | 'celeba_5bit', 25 | 'imagenet32', 26 | 'imagenet64', 27 | ] 28 | ) 29 | parser.add_argument('--dataroot', type=str, default='data') 30 | parser.add_argument('--imagesize', type=int, default=32) 31 | parser.add_argument('--nbits', type=int, default=8) # Only used for celebahq. 32 | 33 | # Sampling parameters. 34 | parser.add_argument('--real', type=eval, choices=[True, False], default=False) 35 | parser.add_argument('--nrow', type=int, default=10) 36 | parser.add_argument('--ncol', type=int, default=10) 37 | parser.add_argument('--temp', type=float, default=1.0) 38 | parser.add_argument('--nbatches', type=int, default=5) 39 | parser.add_argument('--save-each', type=eval, choices=[True, False], default=False) 40 | 41 | parser.add_argument('--block', type=str, choices=['resblock', 'coupling'], default='resblock') 42 | 43 | parser.add_argument('--coeff', type=float, default=0.98) 44 | parser.add_argument('--vnorms', type=str, default='2222') 45 | parser.add_argument('--n-lipschitz-iters', type=int, default=None) 46 | parser.add_argument('--sn-tol', type=float, default=1e-3) 47 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False) 48 | 49 | parser.add_argument('--n-power-series', type=int, default=None) 50 | parser.add_argument('--factor-out', type=eval, choices=[True, False], default=False) 51 | parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric') 52 | parser.add_argument('--n-samples', type=int, default=1) 53 | parser.add_argument('--n-exact-terms', type=int, default=2) 54 | parser.add_argument('--var-reduc-lr', type=float, default=0) 55 | parser.add_argument('--neumann-grad', type=eval, choices=[True, False], default=True) 56 | parser.add_argument('--mem-eff', type=eval, choices=[True, False], default=True) 57 | 58 | parser.add_argument('--act', type=str, 
choices=ACT_FNS.keys(), default='swish') 59 | parser.add_argument('--idim', type=int, default=512) 60 | parser.add_argument('--nblocks', type=str, default='16-16-16') 61 | parser.add_argument('--squeeze-first', type=eval, default=False, choices=[True, False]) 62 | parser.add_argument('--actnorm', type=eval, default=True, choices=[True, False]) 63 | parser.add_argument('--fc-actnorm', type=eval, default=False, choices=[True, False]) 64 | parser.add_argument('--batchnorm', type=eval, default=False, choices=[True, False]) 65 | parser.add_argument('--dropout', type=float, default=0.) 66 | parser.add_argument('--fc', type=eval, default=False, choices=[True, False]) 67 | parser.add_argument('--kernels', type=str, default='3-1-3') 68 | parser.add_argument('--quadratic', type=eval, choices=[True, False], default=False) 69 | parser.add_argument('--fc-end', type=eval, choices=[True, False], default=True) 70 | parser.add_argument('--fc-idim', type=int, default=128) 71 | parser.add_argument('--preact', type=eval, choices=[True, False], default=True) 72 | parser.add_argument('--padding', type=int, default=0) 73 | parser.add_argument('--first-resblock', type=eval, choices=[True, False], default=True) 74 | 75 | parser.add_argument('--task', type=str, choices=['density'], default='density') 76 | parser.add_argument('--rcrop-pad-mode', type=str, choices=['constant', 'reflect'], default='reflect') 77 | parser.add_argument('--padding-dist', type=str, choices=['uniform', 'gaussian'], default='uniform') 78 | 79 | parser.add_argument('--resume', type=str, required=True) 80 | parser.add_argument('--nworkers', type=int, default=4) 81 | args = parser.parse_args() 82 | 83 | W = args.ncol 84 | H = args.nrow 85 | 86 | args.batchsize = W * H 87 | args.val_batchsize = W * H 88 | 89 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 90 | 91 | if device.type == 'cuda': 92 | print('Found {} CUDA devices.'.format(torch.cuda.device_count())) 93 | for i in 
range(torch.cuda.device_count()): 94 | props = torch.cuda.get_device_properties(i) 95 | print('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3))) 96 | else: 97 | print('WARNING: Using device {}'.format(device)) 98 | 99 | 100 | def geometric_logprob(ns, p): 101 | return torch.log(1 - p + 1e-10) * (ns - 1) + torch.log(p + 1e-10) 102 | 103 | 104 | def standard_normal_sample(size): 105 | return torch.randn(size) 106 | 107 | 108 | def standard_normal_logprob(z): 109 | logZ = -0.5 * math.log(2 * math.pi) 110 | return logZ - z.pow(2) / 2 111 | 112 | 113 | def count_parameters(model): 114 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 115 | 116 | 117 | def add_noise(x, nvals=255): 118 | """ 119 | [0, 1] -> [0, nvals] -> add noise -> [0, 1] 120 | """ 121 | noise = x.new().resize_as_(x).uniform_() 122 | x = x * nvals + noise 123 | x = x / (nvals + 1) 124 | return x 125 | 126 | 127 | def update_lr(optimizer, itr): 128 | iter_frac = min(float(itr + 1) / max(args.warmup_iters, 1), 1.0) 129 | lr = args.lr * iter_frac 130 | for param_group in optimizer.param_groups: 131 | param_group["lr"] = lr 132 | 133 | 134 | def add_padding(x): 135 | if args.padding > 0: 136 | u = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).uniform_() 137 | logpu = torch.zeros_like(u).sum([1, 2, 3]) 138 | return torch.cat([u, x], dim=1), logpu 139 | else: 140 | return x, torch.zeros(x.shape[0]).to(x) 141 | 142 | 143 | def remove_padding(x): 144 | if args.padding > 0: 145 | return x[:, args.padding:, :, :] 146 | else: 147 | return x 148 | 149 | 150 | def reduce_bits(x): 151 | if args.nbits < 8: 152 | x = x * 255 153 | x = torch.floor(x / 2**(8 - args.nbits)) 154 | x = x / 2**args.nbits 155 | return x 156 | 157 | 158 | def update_lipschitz(model): 159 | for m in model.modules(): 160 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear): 161 | m.compute_weight(update=True) 162 | if isinstance(m, 
base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 163 | m.compute_weight(update=True) 164 | 165 | 166 | print('Loading dataset {}'.format(args.data), flush=True) 167 | # Dataset and hyperparameters 168 | if args.data == 'cifar10': 169 | im_dim = 3 170 | n_classes = 10 171 | 172 | if args.task in ['classification', 'hybrid']: 173 | 174 | if args.real: 175 | 176 | # Classification-specific preprocessing. 177 | transform_train = transforms.Compose([ 178 | transforms.Resize(args.imagesize), 179 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode), 180 | transforms.RandomHorizontalFlip(), 181 | transforms.ToTensor(), 182 | add_noise, 183 | ]) 184 | 185 | transform_test = transforms.Compose([ 186 | transforms.Resize(args.imagesize), 187 | transforms.ToTensor(), 188 | add_noise, 189 | ]) 190 | 191 | # Remove the logit transform. 192 | init_layer = layers.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 193 | else: 194 | if args.real: 195 | transform_train = transforms.Compose([ 196 | transforms.Resize(args.imagesize), 197 | transforms.RandomHorizontalFlip(), 198 | transforms.ToTensor(), 199 | add_noise, 200 | ]) 201 | transform_test = transforms.Compose([ 202 | transforms.Resize(args.imagesize), 203 | transforms.ToTensor(), 204 | add_noise, 205 | ]) 206 | init_layer = layers.LogitTransform(0.05) 207 | if args.real: 208 | train_loader = torch.utils.data.DataLoader( 209 | datasets.CIFAR10(args.dataroot, train=True, transform=transform_train), 210 | batch_size=args.batchsize, 211 | shuffle=True, 212 | num_workers=args.nworkers, 213 | ) 214 | test_loader = torch.utils.data.DataLoader( 215 | datasets.CIFAR10(args.dataroot, train=False, transform=transform_test), 216 | batch_size=args.val_batchsize, 217 | shuffle=False, 218 | num_workers=args.nworkers, 219 | ) 220 | elif args.data == 'mnist': 221 | im_dim = 1 222 | init_layer = layers.LogitTransform(1e-6) 223 | n_classes = 10 224 | 225 | if args.real: 226 | 
train_loader = torch.utils.data.DataLoader( 227 | datasets.MNIST( 228 | args.dataroot, train=True, transform=transforms.Compose([ 229 | transforms.Resize(args.imagesize), 230 | transforms.ToTensor(), 231 | add_noise, 232 | ]) 233 | ), 234 | batch_size=args.batchsize, 235 | shuffle=True, 236 | num_workers=args.nworkers, 237 | ) 238 | test_loader = torch.utils.data.DataLoader( 239 | datasets.MNIST( 240 | args.dataroot, train=False, transform=transforms.Compose([ 241 | transforms.Resize(args.imagesize), 242 | transforms.ToTensor(), 243 | add_noise, 244 | ]) 245 | ), 246 | batch_size=args.val_batchsize, 247 | shuffle=False, 248 | num_workers=args.nworkers, 249 | ) 250 | elif args.data == 'svhn': 251 | im_dim = 3 252 | init_layer = layers.LogitTransform(0.05) 253 | n_classes = 10 254 | 255 | if args.real: 256 | train_loader = torch.utils.data.DataLoader( 257 | vdsets.SVHN( 258 | args.dataroot, split='train', download=True, transform=transforms.Compose([ 259 | transforms.Resize(args.imagesize), 260 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode), 261 | transforms.ToTensor(), 262 | add_noise, 263 | ]) 264 | ), 265 | batch_size=args.batchsize, 266 | shuffle=True, 267 | num_workers=args.nworkers, 268 | ) 269 | test_loader = torch.utils.data.DataLoader( 270 | vdsets.SVHN( 271 | args.dataroot, split='test', download=True, transform=transforms.Compose([ 272 | transforms.Resize(args.imagesize), 273 | transforms.ToTensor(), 274 | add_noise, 275 | ]) 276 | ), 277 | batch_size=args.val_batchsize, 278 | shuffle=False, 279 | num_workers=args.nworkers, 280 | ) 281 | elif args.data == 'celebahq': 282 | im_dim = 3 283 | init_layer = layers.LogitTransform(0.05) 284 | 285 | if args.real: 286 | train_loader = torch.utils.data.DataLoader( 287 | datasets.CelebAHQ( 288 | train=True, transform=transforms.Compose([ 289 | transforms.ToPILImage(), 290 | transforms.RandomHorizontalFlip(), 291 | transforms.ToTensor(), 292 | reduce_bits, 293 | lambda x: add_noise(x, 
nvals=2**args.nbits), 294 | ]) 295 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 296 | ) 297 | test_loader = torch.utils.data.DataLoader( 298 | datasets.CelebAHQ( 299 | train=False, transform=transforms.Compose([ 300 | reduce_bits, 301 | lambda x: add_noise(x, nvals=2**args.nbits), 302 | ]) 303 | ), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 304 | ) 305 | elif args.data == 'celeba_5bit': 306 | im_dim = 3 307 | init_layer = layers.LogitTransform(0.05) 308 | if args.imagesize != 64: 309 | print('Changing image size to 64.') 310 | args.imagesize = 64 311 | 312 | if args.real: 313 | train_loader = torch.utils.data.DataLoader( 314 | datasets.CelebA5bit( 315 | train=True, transform=transforms.Compose([ 316 | transforms.ToPILImage(), 317 | transforms.RandomHorizontalFlip(), 318 | transforms.ToTensor(), 319 | lambda x: add_noise(x, nvals=32), 320 | ]) 321 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 322 | ) 323 | test_loader = torch.utils.data.DataLoader( 324 | datasets.CelebA5bit(train=False, transform=transforms.Compose([ 325 | lambda x: add_noise(x, nvals=32), 326 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 327 | ) 328 | elif args.data == 'imagenet32': 329 | im_dim = 3 330 | init_layer = layers.LogitTransform(0.05) 331 | if args.imagesize != 32: 332 | print('Changing image size to 32.') 333 | args.imagesize = 32 334 | 335 | if args.real: 336 | train_loader = torch.utils.data.DataLoader( 337 | datasets.Imagenet32(train=True, transform=transforms.Compose([ 338 | add_noise, 339 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 340 | ) 341 | test_loader = torch.utils.data.DataLoader( 342 | datasets.Imagenet32(train=False, transform=transforms.Compose([ 343 | add_noise, 344 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 345 | ) 346 | elif args.data == 'imagenet64': 347 | im_dim = 3 348 | init_layer = 
layers.LogitTransform(0.05) 349 | if args.imagesize != 64: 350 | print('Changing image size to 64.') 351 | args.imagesize = 64 352 | 353 | if args.real: 354 | train_loader = torch.utils.data.DataLoader( 355 | datasets.Imagenet64(train=True, transform=transforms.Compose([ 356 | add_noise, 357 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 358 | ) 359 | test_loader = torch.utils.data.DataLoader( 360 | datasets.Imagenet64(train=False, transform=transforms.Compose([ 361 | add_noise, 362 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 363 | ) 364 | 365 | if args.task in ['classification', 'hybrid']: 366 | try: 367 | n_classes 368 | except NameError: 369 | raise ValueError('Cannot perform classification with {}'.format(args.data)) 370 | else: 371 | n_classes = 1 372 | 373 | print('Dataset loaded.', flush=True) 374 | print('Creating model.', flush=True) 375 | 376 | input_size = (args.batchsize, im_dim + args.padding, args.imagesize, args.imagesize) 377 | 378 | if args.squeeze_first: 379 | input_size = (input_size[0], input_size[1] * 4, input_size[2] // 2, input_size[3] // 2) 380 | squeeze_layer = layers.SqueezeLayer(2) 381 | 382 | # Model 383 | model = ResidualFlow( 384 | input_size, 385 | n_blocks=list(map(int, args.nblocks.split('-'))), 386 | intermediate_dim=args.idim, 387 | factor_out=args.factor_out, 388 | quadratic=args.quadratic, 389 | init_layer=init_layer, 390 | actnorm=args.actnorm, 391 | fc_actnorm=args.fc_actnorm, 392 | batchnorm=args.batchnorm, 393 | dropout=args.dropout, 394 | fc=args.fc, 395 | coeff=args.coeff, 396 | vnorms=args.vnorms, 397 | n_lipschitz_iters=args.n_lipschitz_iters, 398 | sn_atol=args.sn_tol, 399 | sn_rtol=args.sn_tol, 400 | n_power_series=args.n_power_series, 401 | n_dist=args.n_dist, 402 | n_samples=args.n_samples, 403 | kernels=args.kernels, 404 | activation_fn=args.act, 405 | fc_end=args.fc_end, 406 | fc_idim=args.fc_idim, 407 | n_exact_terms=args.n_exact_terms, 408 | 
preact=args.preact, 409 | neumann_grad=args.neumann_grad, 410 | grad_in_forward=args.mem_eff, 411 | first_resblock=args.first_resblock, 412 | learn_p=args.learn_p, 413 | block_type=args.block, 414 | ) 415 | 416 | model.to(device) 417 | 418 | print('Initializing model.', flush=True) 419 | 420 | with torch.no_grad(): 421 | x = torch.rand(1, *input_size[1:]).to(device) 422 | model(x) 423 | print('Restoring from checkpoint.', flush=True) 424 | checkpt = torch.load(args.resume, map_location=device) 425 | # Restore the trained weights; the EMA copy is swapped in below for sampling. 426 | model.load_state_dict(checkpt['state_dict'], strict=True) 427 | 428 | ema = utils.ExponentialMovingAverage(model) 429 | ema.set(checkpt['ema']) 430 | ema.swap() 431 | 432 | print(model, flush=True) 433 | 434 | model.eval() 435 | print('Updating lipschitz.', flush=True) 436 | update_lipschitz(model) 437 | 438 | 439 | def visualize(model): 440 | utils.makedirs('{}_imgs_t{}'.format(args.data, args.temp)) 441 | 442 | with torch.no_grad(): 443 | 444 | for i in tqdm(range(args.nbatches)): 445 | # random samples 446 | rand_z = torch.randn(args.batchsize, (im_dim + args.padding) * args.imagesize * args.imagesize).to(device) 447 | rand_z = rand_z * args.temp 448 | fake_imgs = model(rand_z, inverse=True).view(-1, *input_size[1:]) 449 | if args.squeeze_first: fake_imgs = squeeze_layer.inverse(fake_imgs) 450 | fake_imgs = remove_padding(fake_imgs) 451 | fake_imgs = fake_imgs.view(-1, im_dim, args.imagesize, args.imagesize) 452 | fake_imgs = fake_imgs.cpu() 453 | 454 | if args.save_each: 455 | for j in range(fake_imgs.shape[0]): 456 | save_image( 457 | fake_imgs[j], '{}_imgs_t{}/{}.png'.format(args.data, args.temp, args.batchsize * i + j), nrow=1, 458 | padding=0, range=(0, 1), pad_value=0 459 | ) 460 | else: 461 | save_image( 462 | fake_imgs, 'imgs/{}_t{}_samples{}.png'.format(args.data, args.temp, i), nrow=W, padding=2, 463 | range=(0, 1), pad_value=1 464 | ) 465 | 466 | 467 | utils.makedirs('imgs')  # The save_image calls below write into imgs/. 468 | if args.real:
469 | real_imgs = test_loader.__iter__().__next__()[0] 470 | save_image( 471 | real_imgs.cpu().float(), 'imgs/{}_real.png'.format(args.data), nrow=W, padding=2, range=(0, 1), pad_value=1 472 | ) 473 | 474 | visualize(model) 475 | -------------------------------------------------------------------------------- /resflows/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.datasets as vdsets 3 | 4 | 5 | class Dataset(object): 6 | 7 | def __init__(self, loc, transform=None, in_mem=True): 8 | self.in_mem = in_mem 9 | self.dataset = torch.load(loc) 10 | if in_mem: self.dataset = self.dataset.float().div(255) 11 | self.transform = transform 12 | 13 | def __len__(self): 14 | return self.dataset.size(0) 15 | 16 | @property 17 | def ndim(self): 18 | return self.dataset.size(1) 19 | 20 | def __getitem__(self, index): 21 | x = self.dataset[index] 22 | if not self.in_mem: x = x.float().div(255) 23 | x = self.transform(x) if self.transform is not None else x 24 | return x, 0 25 | 26 | 27 | class MNIST(object): 28 | 29 | def __init__(self, dataroot, train=True, transform=None): 30 | self.mnist = vdsets.MNIST(dataroot, train=train, download=True, transform=transform) 31 | 32 | def __len__(self): 33 | return len(self.mnist) 34 | 35 | @property 36 | def ndim(self): 37 | return 1 38 | 39 | def __getitem__(self, index): 40 | return self.mnist[index] 41 | 42 | 43 | class CIFAR10(object): 44 | 45 | def __init__(self, dataroot, train=True, transform=None): 46 | self.cifar10 = vdsets.CIFAR10(dataroot, train=train, download=True, transform=transform) 47 | 48 | def __len__(self): 49 | return len(self.cifar10) 50 | 51 | @property 52 | def ndim(self): 53 | return 3 54 | 55 | def __getitem__(self, index): 56 | return self.cifar10[index] 57 | 58 | 59 | class CelebA5bit(object): 60 | 61 | LOC = 'data/celebahq64_5bit/celeba_full_64x64_5bit.pth' 62 | 63 | def __init__(self, train=True, transform=None): 64 | 
self.dataset = torch.load(self.LOC).float().div(31) 65 | if not train: 66 | self.dataset = self.dataset[:5000] 67 | self.transform = transform 68 | 69 | def __len__(self): 70 | return self.dataset.size(0) 71 | 72 | @property 73 | def ndim(self): 74 | return self.dataset.size(1) 75 | 76 | def __getitem__(self, index): 77 | x = self.dataset[index] 78 | x = self.transform(x) if self.transform is not None else x 79 | return x, 0 80 | 81 | 82 | class CelebAHQ(Dataset): 83 | TRAIN_LOC = 'data/celebahq/celeba256_train.pth' 84 | TEST_LOC = 'data/celebahq/celeba256_validation.pth' 85 | 86 | def __init__(self, train=True, transform=None): 87 | return super(CelebAHQ, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform) 88 | 89 | 90 | class Imagenet32(Dataset): 91 | TRAIN_LOC = 'data/imagenet32/train_32x32.pth' 92 | TEST_LOC = 'data/imagenet32/valid_32x32.pth' 93 | 94 | def __init__(self, train=True, transform=None): 95 | return super(Imagenet32, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform) 96 | 97 | 98 | class Imagenet64(Dataset): 99 | TRAIN_LOC = 'data/imagenet64/train_64x64.pth' 100 | TEST_LOC = 'data/imagenet64/valid_64x64.pth' 101 | 102 | def __init__(self, train=True, transform=None): 103 | return super(Imagenet64, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform, in_mem=False) 104 | -------------------------------------------------------------------------------- /resflows/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .act_norm import * 2 | from .container import * 3 | from .coupling import * 4 | from .elemwise import * 5 | from .iresblock import * 6 | from .normalization import * 7 | from .squeeze import * 8 | from .glow import * 9 | -------------------------------------------------------------------------------- /resflows/layers/act_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | from torch.nn import Parameter 4 | 5 | __all__ = ['ActNorm1d', 'ActNorm2d'] 6 | 7 | 8 | class ActNormNd(nn.Module): 9 | 10 | def __init__(self, num_features, eps=1e-12): 11 | super(ActNormNd, self).__init__() 12 | self.num_features = num_features 13 | self.eps = eps 14 | self.weight = Parameter(torch.Tensor(num_features)) 15 | self.bias = Parameter(torch.Tensor(num_features)) 16 | self.register_buffer('initialized', torch.tensor(0)) 17 | 18 | @property 19 | def shape(self): 20 | raise NotImplementedError 21 | 22 | def forward(self, x, logpx=None): 23 | c = x.size(1) 24 | 25 | if not self.initialized: 26 | with torch.no_grad(): 27 | # compute batch statistics 28 | x_t = x.transpose(0, 1).contiguous().view(c, -1) 29 | batch_mean = torch.mean(x_t, dim=1) 30 | batch_var = torch.var(x_t, dim=1) 31 | 32 | # for numerical issues 33 | batch_var = torch.max(batch_var, torch.tensor(0.2).to(batch_var)) 34 | 35 | self.bias.data.copy_(-batch_mean) 36 | self.weight.data.copy_(-0.5 * torch.log(batch_var)) 37 | self.initialized.fill_(1) 38 | 39 | bias = self.bias.view(*self.shape).expand_as(x) 40 | weight = self.weight.view(*self.shape).expand_as(x) 41 | 42 | y = (x + bias) * torch.exp(weight) 43 | 44 | if logpx is None: 45 | return y 46 | else: 47 | return y, logpx - self._logdetgrad(x) 48 | 49 | def inverse(self, y, logpy=None): 50 | assert self.initialized 51 | bias = self.bias.view(*self.shape).expand_as(y) 52 | weight = self.weight.view(*self.shape).expand_as(y) 53 | 54 | x = y * torch.exp(-weight) - bias 55 | 56 | if logpy is None: 57 | return x 58 | else: 59 | return x, logpy + self._logdetgrad(x) 60 | 61 | def _logdetgrad(self, x): 62 | return self.weight.view(*self.shape).expand(*x.size()).contiguous().view(x.size(0), -1).sum(1, keepdim=True) 63 | 64 | def __repr__(self): 65 | return ('{name}({num_features})'.format(name=self.__class__.__name__, **self.__dict__)) 66 | 67 | 68 | class ActNorm1d(ActNormNd): 69 | 70 | @property 71 | def shape(self): 
72 |         return [1, -1]
73 |
74 |
75 | class ActNorm2d(ActNormNd):
76 |
77 |     @property
78 |     def shape(self):
79 |         return [1, -1, 1, 1]
80 |
--------------------------------------------------------------------------------
/resflows/layers/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .activations import *
2 | from .lipschitz import *
3 | from .mixed_lipschitz import *
4 |
--------------------------------------------------------------------------------
/resflows/layers/base/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Identity(nn.Module):
7 |
8 |     def forward(self, x):
9 |         return x
10 |
11 |
12 | class FullSort(nn.Module):
13 |
14 |     def forward(self, x):
15 |         return torch.sort(x, 1)[0]
16 |
17 |
18 | class MaxMin(nn.Module):
19 |
20 |     def forward(self, x):
21 |         b, d = x.shape
22 |         max_vals = torch.max(x.view(b, d // 2, 2), 2)[0]
23 |         min_vals = torch.min(x.view(b, d // 2, 2), 2)[0]
24 |         return torch.cat([max_vals, min_vals], 1)
25 |
26 |
27 | class LipschitzCube(nn.Module):
28 |
29 |     def forward(self, x):
30 |         return (x >= 1).to(x) * (x - 2 / 3) + (x <= -1).to(x) * (x + 2 / 3) + ((x > -1) * (x < 1)).to(x) * x**3 / 3
31 |
32 |
33 | class SwishFn(torch.autograd.Function):
34 |
35 |     @staticmethod
36 |     def forward(ctx, x, beta):
37 |         beta_sigm = torch.sigmoid(beta * x)
38 |         output = x * beta_sigm
39 |         ctx.save_for_backward(x, output, beta)
40 |         return output / 1.1
41 |
42 |     @staticmethod
43 |     def backward(ctx, grad_output):
44 |         x, output, beta = ctx.saved_tensors
45 |         beta_sigm = output / x
46 |         grad_x = grad_output * (beta * output + beta_sigm * (1 - beta * output))
47 |         grad_beta = torch.sum(grad_output * (x * output - output * output)).expand_as(beta)
48 |         return grad_x / 1.1, grad_beta / 1.1
49 |
50 |
51 | class Swish(nn.Module):
52 |
53 |     def __init__(self):
54 |         super(Swish, self).__init__()
55 |         self.beta = nn.Parameter(torch.tensor([0.5]))
56 |
57 |     def forward(self, x):
58 |         return (x * torch.sigmoid_(x * F.softplus(self.beta))).div_(1.1)
59 |
60 |
61 | if __name__ == '__main__':
62 |
63 |     m = Swish()
64 |     xx = torch.linspace(-5, 5, 1000).requires_grad_(True)
65 |     yy = m(xx)
66 |     dd, dbeta = torch.autograd.grad(yy.sum() * 2, [xx, m.beta])
67 |
68 |     import matplotlib.pyplot as plt
69 |
70 |     plt.plot(xx.detach().numpy(), yy.detach().numpy(), label='Func')
71 |     plt.plot(xx.detach().numpy(), dd.detach().numpy(), label='Deriv')
72 |     plt.plot(xx.detach().numpy(), torch.max(dd.detach().abs() - 1, torch.zeros_like(dd)).numpy(), label='|Deriv| > 1')
73 |     plt.legend()
74 |     plt.tight_layout()
75 |     plt.show()
76 |
--------------------------------------------------------------------------------
/resflows/layers/base/lipschitz.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | import torch.nn.functional as F
6 |
7 | from .utils import _pair
8 | from .mixed_lipschitz import InducedNormLinear, InducedNormConv2d
9 |
10 | __all__ = ['SpectralNormLinear', 'SpectralNormConv2d', 'LopLinear', 'LopConv2d', 'get_linear', 'get_conv2d']
11 |
12 |
13 | class SpectralNormLinear(nn.Module):
14 |
15 |     def __init__(
16 |         self, in_features, out_features, bias=True, coeff=0.97, n_iterations=None, atol=None, rtol=None, **unused_kwargs
17 |     ):
18 |         del unused_kwargs
19 |         super(SpectralNormLinear, self).__init__()
20 |         self.in_features = in_features
21 |         self.out_features = out_features
22 |         self.coeff = coeff
23 |         self.n_iterations = n_iterations
24 |         self.atol = atol
25 |         self.rtol = rtol
26 |         self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
27 |         if bias:
28 |             self.bias = nn.Parameter(torch.Tensor(out_features))
29 |         else:
30 |             self.register_parameter('bias', None)
31 |         self.reset_parameters()
32 |
33 |         h, w = self.weight.shape
34 |         self.register_buffer('scale', torch.tensor(0.))
35 |         self.register_buffer('u', F.normalize(self.weight.new_empty(h).normal_(0, 1), dim=0))
36 |         self.register_buffer('v', F.normalize(self.weight.new_empty(w).normal_(0, 1), dim=0))
37 |         self.compute_weight(True, 200)
38 |
39 |     def reset_parameters(self):
40 |         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
41 |         if self.bias is not None:
42 |             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
43 |             bound = 1 / math.sqrt(fan_in)
44 |             init.uniform_(self.bias, -bound, bound)
45 |
46 |     def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
47 |         n_iterations = self.n_iterations if n_iterations is None else n_iterations
48 |         atol = self.atol if atol is None else atol
49 |         rtol = self.rtol if rtol is None else rtol
50 |
51 |         if n_iterations is None and (atol is None or rtol is None):
52 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
53 |
54 |         if n_iterations is None:
55 |             n_iterations = 20000
56 |
57 |         u = self.u
58 |         v = self.v
59 |         weight = self.weight
60 |         if update:
61 |             with torch.no_grad():
62 |                 itrs_used = 0.
63 |                 for _ in range(n_iterations):
64 |                     old_v = v.clone()
65 |                     old_u = u.clone()
66 |                     # Spectral norm of weight equals `u^T W v`, where `u` and `v`
67 |                     # are the first left and right singular vectors.
68 |                     # This power iteration produces approximations of `u` and `v`.
69 |                     v = F.normalize(torch.mv(weight.t(), u), dim=0, out=v)
70 |                     u = F.normalize(torch.mv(weight, v), dim=0, out=u)
71 |                     itrs_used = itrs_used + 1
72 |                     if atol is not None and rtol is not None:
73 |                         err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
74 |                         err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
75 |                         tol_u = atol + rtol * torch.max(u)
76 |                         tol_v = atol + rtol * torch.max(v)
77 |                         if err_u < tol_u and err_v < tol_v:
78 |                             break
79 |                 if itrs_used > 0:
80 |                     u = u.clone()
81 |                     v = v.clone()
82 |
83 |         sigma = torch.dot(u, torch.mv(weight, v))
84 |         with torch.no_grad():
85 |             self.scale.copy_(sigma)
86 |         # soft normalization: rescale only when sigma exceeds coeff
87 |         factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
88 |         weight = weight / factor
89 |         return weight
90 |
91 |     def forward(self, input):
92 |         weight = self.compute_weight(update=self.training)
93 |         return F.linear(input, weight, self.bias)
94 |
95 |     def extra_repr(self):
96 |         return 'in_features={}, out_features={}, bias={}, coeff={}, n_iters={}, atol={}, rtol={}'.format(
97 |             self.in_features, self.out_features, self.bias is not None, self.coeff, self.n_iterations, self.atol,
98 |             self.rtol
99 |         )
100 |
101 |
102 | class SpectralNormConv2d(nn.Module):
103 |
104 |     def __init__(
105 |         self, in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, n_iterations=None,
106 |         atol=None, rtol=None, **unused_kwargs
107 |     ):
108 |         del unused_kwargs
109 |         super(SpectralNormConv2d, self).__init__()
110 |         self.in_channels = in_channels
111 |         self.out_channels = out_channels
112 |         self.kernel_size = _pair(kernel_size)
113 |         self.stride = _pair(stride)
114 |         self.padding = _pair(padding)
115 |         self.coeff = coeff
116 |         self.n_iterations = n_iterations
117 |         self.atol = atol
118 |         self.rtol = rtol
119 |         self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
120 |         if bias:
121 |             self.bias = nn.Parameter(torch.Tensor(out_channels))
122 |         else:
123 |             self.register_parameter('bias', None)
124 |         self.reset_parameters()
125 |         self.initialized = False
126 |         self.register_buffer('spatial_dims', torch.tensor([1., 1.]))
127 |         self.register_buffer('scale', torch.tensor(0.))
128 |
129 |     def reset_parameters(self):
130 |         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
131 |         if self.bias is not None:
132 |             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
133 |             bound = 1 / math.sqrt(fan_in)
134 |             init.uniform_(self.bias, -bound, bound)
135 |
136 |     def _initialize_u_v(self):
137 |         if self.kernel_size == (1, 1):
138 |             self.register_buffer('u', F.normalize(self.weight.new_empty(self.out_channels).normal_(0, 1), dim=0))
139 |             self.register_buffer('v', F.normalize(self.weight.new_empty(self.in_channels).normal_(0, 1), dim=0))
140 |         else:
141 |             c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
142 |             with torch.no_grad():
143 |                 num_input_dim = c * h * w
144 |                 v = F.normalize(torch.randn(num_input_dim).to(self.weight), dim=0, eps=1e-12)
145 |                 # forward call to infer the shape
146 |                 u = F.conv2d(v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None)
147 |                 num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3]
148 |                 self.out_shape = u.shape
149 |                 # overwrite u with random init
150 |                 u = F.normalize(torch.randn(num_output_dim).to(self.weight), dim=0, eps=1e-12)
151 |
152 |             self.register_buffer('u', u)
153 |             self.register_buffer('v', v)
154 |
155 |     def compute_weight(self, update=True, n_iterations=None):
156 |         if not self.initialized:
157 |             self._initialize_u_v()
158 |             self.initialized = True
159 |
160 |         if self.kernel_size == (1, 1):
161 |             return self._compute_weight_1x1(update, n_iterations)
162 |         else:
163 |             return self._compute_weight_kxk(update, n_iterations)
164 |
165 |     def _compute_weight_1x1(self, update=True, n_iterations=None, atol=None, rtol=None):
166 |         n_iterations = self.n_iterations if n_iterations is None else n_iterations
167 |         atol = self.atol if atol is None else atol
168 |         rtol = self.rtol if rtol is None else rtol
169 |
170 |         if n_iterations is None and (atol is None or rtol is None):
171 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
172 |
173 |         if n_iterations is None:
174 |             n_iterations = 20000
175 |
176 |         u = self.u
177 |         v = self.v
178 |         weight = self.weight.view(self.out_channels, self.in_channels)
179 |         if update:
180 |             with torch.no_grad():
181 |                 itrs_used = 0
182 |                 for _ in range(n_iterations):
183 |                     old_v = v.clone()
184 |                     old_u = u.clone()
185 |                     # Spectral norm of weight equals `u^T W v`, where `u` and `v`
186 |                     # are the first left and right singular vectors.
187 |                     # This power iteration produces approximations of `u` and `v`.
188 |                     v = F.normalize(torch.mv(weight.t(), u), dim=0, out=v)
189 |                     u = F.normalize(torch.mv(weight, v), dim=0, out=u)
190 |                     itrs_used = itrs_used + 1
191 |                     if atol is not None and rtol is not None:
192 |                         err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
193 |                         err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
194 |                         tol_u = atol + rtol * torch.max(u)
195 |                         tol_v = atol + rtol * torch.max(v)
196 |                         if err_u < tol_u and err_v < tol_v:
197 |                             break
198 |                 if itrs_used > 0:
199 |                     u = u.clone()
200 |                     v = v.clone()
201 |
202 |         sigma = torch.dot(u, torch.mv(weight, v))
203 |         with torch.no_grad():
204 |             self.scale.copy_(sigma)
205 |         # soft normalization: rescale only when sigma exceeds coeff
206 |         factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
207 |         weight = weight / factor
208 |         return weight.view(self.out_channels, self.in_channels, 1, 1)
209 |
210 |     def _compute_weight_kxk(self, update=True, n_iterations=None, atol=None, rtol=None):
211 |         n_iterations = self.n_iterations if n_iterations is None else n_iterations
212 |         atol = self.atol if atol is None else atol
213 |         rtol = self.rtol if rtol is None else rtol
214 |
215 |         if n_iterations is None and (atol is None or rtol is None):
216 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
217 |
218 |         if n_iterations is None:
219 |             n_iterations = 20000
220 |
221 |         u = self.u
222 |         v = self.v
223 |         weight = self.weight
224 |         c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
225 |         if update:
226 |             with torch.no_grad():
227 |                 itrs_used = 0
228 |                 for _ in range(n_iterations):
229 |                     old_u = u.clone()
230 |                     old_v = v.clone()
231 |                     v_s = F.conv_transpose2d(
232 |                         u.view(self.out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0
233 |                     )
234 |                     v = F.normalize(v_s.view(-1), dim=0, out=v)
235 |                     u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
236 |                     u = F.normalize(u_s.view(-1), dim=0, out=u)
237 |                     itrs_used = itrs_used + 1
238 |                     if atol is not None and rtol is not None:
239 |                         err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
240 |                         err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
241 |                         tol_u = atol + rtol * torch.max(u)
242 |                         tol_v = atol + rtol * torch.max(v)
243 |                         if err_u < tol_u and err_v < tol_v:
244 |                             break
245 |                 if itrs_used > 0:
246 |                     u = u.clone()
247 |                     v = v.clone()
248 |
249 |         weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
250 |         weight_v = weight_v.view(-1)
251 |         sigma = torch.dot(u.view(-1), weight_v)
252 |         with torch.no_grad():
253 |             self.scale.copy_(sigma)
254 |         # soft normalization: rescale only when sigma exceeds coeff
255 |         factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
256 |         weight = weight / factor
257 |         return weight
258 |
259 |     def forward(self, input):
260 |         if not self.initialized: self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims))
261 |         weight = self.compute_weight(update=self.training)
262 |         return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)
263 |
264 |     def extra_repr(self):
265 |         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}')
266 |         if self.padding != (0,) * len(self.padding):
267 |             s += ', padding={padding}'
268 |         if self.bias is None:
269 |             s += ', bias=False'
270 |         s += ', coeff={}, n_iters={}, atol={}, rtol={}'.format(self.coeff, self.n_iterations, self.atol, self.rtol)
271 |         return s.format(**self.__dict__)
272 |
273 |
274 | class LopLinear(nn.Linear):
275 |     """Lipschitz constant defined using operator norms."""
276 |
277 |     def __init__(
278 |         self,
279 |         in_features,
280 |         out_features,
281 |         bias=True,
282 |         coeff=0.97,
283 |         domain=float('inf'),
284 |         codomain=float('inf'),
285 |         local_constraint=True,
286 |         **unused_kwargs,
287 |     ):
288 |         del unused_kwargs
289 |         super(LopLinear, self).__init__(in_features, out_features, bias)
290 |         self.coeff = coeff
291 |         self.domain = domain
292 |         self.codomain = codomain
293 |         self.local_constraint = local_constraint
294 |         max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
295 |         self.max_across_dim = 1 if max_across_input_dims else 0
296 |         self.register_buffer('scale', torch.tensor(0.))
297 |
298 |     def compute_weight(self):
299 |         scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
300 |         if not self.local_constraint: scale = scale.max()
301 |         with torch.no_grad():
302 |             self.scale.copy_(scale.max())
303 |
304 |         # soft normalization
305 |         factor = torch.max(torch.ones(1).to(self.weight), scale / self.coeff)
306 |
307 |         return self.weight / factor
308 |
309 |     def forward(self, input):
310 |         weight = self.compute_weight()
311 |         return F.linear(input, weight, self.bias)
312 |
313 |     def extra_repr(self):
314 |         s = super(LopLinear, self).extra_repr()
315 |         return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
316 |             self.coeff, self.domain, self.codomain, self.local_constraint
317 |         )
318 |
319 |
320 | class LopConv2d(nn.Conv2d):
321 |     """Lipschitz constant defined using operator norms."""
322 |
323 |     def __init__(
324 |         self,
325 |         in_channels,
326 |         out_channels,
327 |         kernel_size,
328 |         stride,
329 |         padding,
330 |         bias=True,
331 |         coeff=0.97,
332 |         domain=float('inf'),
333 |         codomain=float('inf'),
334 |         local_constraint=True,
335 |         **unused_kwargs,
336 |     ):
337 |         del unused_kwargs
338 |         super(LopConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias)
339 |         self.coeff = coeff
340 |         self.domain = domain
341 |         self.codomain = codomain
342 |         self.local_constraint = local_constraint
343 |         max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
344 |         self.max_across_dim = 1 if max_across_input_dims else 0
345 |         self.register_buffer('scale', torch.tensor(0.))
346 |
347 |     def compute_weight(self):
348 |         scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
349 |         if not self.local_constraint: scale = scale.max()
350 |         with torch.no_grad():
351 |             self.scale.copy_(scale.max())
352 |
353 |         # soft normalization
354 |         factor = torch.max(torch.ones(1).to(self.weight.device), scale / self.coeff)
355 |
356 |         return self.weight / factor
357 |
358 |     def forward(self, input):
359 |         weight = self.compute_weight()
360 |         return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)
361 |
362 |     def extra_repr(self):
363 |         s = super(LopConv2d, self).extra_repr()
364 |         return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
365 |             self.coeff, self.domain, self.codomain, self.local_constraint
366 |         )
367 |
368 |
369 | class LipNormLinear(nn.Linear):
370 |     """Lipschitz constant defined using operator norms."""
371 |
372 |     def __init__(
373 |         self,
374 |         in_features,
375 |         out_features,
376 |         bias=True,
377 |         coeff=0.97,
378 |         domain=float('inf'),
379 |         codomain=float('inf'),
380 |         local_constraint=True,
381 |         **unused_kwargs,
382 |     ):
383 |         del unused_kwargs
384 |         super(LipNormLinear, self).__init__(in_features, out_features, bias)
385 |         self.coeff = coeff
386 |         self.domain = domain
387 |         self.codomain = codomain
388 |         self.local_constraint = local_constraint
389 |         max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
390 |         self.max_across_dim = 1 if max_across_input_dims else 0
391 |
392 |         # Initialize scale parameter.
393 |         with torch.no_grad():
394 |             w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
395 |             if not self.local_constraint: w_scale = w_scale.max()
396 |             self.scale = nn.Parameter(_logit(w_scale / self.coeff))
397 |
398 |     def compute_weight(self):
399 |         w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
400 |         if not self.local_constraint: w_scale = w_scale.max()
401 |         return self.weight / w_scale * torch.sigmoid(self.scale) * self.coeff
402 |
403 |     def forward(self, input):
404 |         weight = self.compute_weight()
405 |         return F.linear(input, weight, self.bias)
406 |
407 |     def extra_repr(self):
408 |         s = super(LipNormLinear, self).extra_repr()
409 |         return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
410 |             self.coeff, self.domain, self.codomain, self.local_constraint
411 |         )
412 |
413 |
414 | class LipNormConv2d(nn.Conv2d):
415 |     """Lipschitz constant defined using operator norms."""
416 |
417 |     def __init__(
418 |         self,
419 |         in_channels,
420 |         out_channels,
421 |         kernel_size,
422 |         stride,
423 |         padding,
424 |         bias=True,
425 |         coeff=0.97,
426 |         domain=float('inf'),
427 |         codomain=float('inf'),
428 |         local_constraint=True,
429 |         **unused_kwargs,
430 |     ):
431 |         del unused_kwargs
432 |         super(LipNormConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias)
433 |         self.coeff = coeff
434 |         self.domain = domain
435 |         self.codomain = codomain
436 |         self.local_constraint = local_constraint
437 |         max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
438 |         self.max_across_dim = 1 if max_across_input_dims else 0
439 |
440 |         # Initialize scale parameter.
441 |         with torch.no_grad():
442 |             w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
443 |             if not self.local_constraint: w_scale = w_scale.max()
444 |             self.scale = nn.Parameter(_logit(w_scale / self.coeff))
445 |
446 |     def compute_weight(self):
447 |         w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
448 |         if not self.local_constraint: w_scale = w_scale.max()
449 |         return self.weight / w_scale * torch.sigmoid(self.scale) * self.coeff
450 |
451 |     def forward(self, input):
452 |         weight = self.compute_weight()
453 |         return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)
454 |
455 |     def extra_repr(self):
456 |         s = super(LipNormConv2d, self).extra_repr()
457 |         return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
458 |             self.coeff, self.domain, self.codomain, self.local_constraint
459 |         )
460 |
461 |
462 | def _logit(p):
463 |     p = torch.max(torch.ones(1) * 0.1, torch.min(torch.ones(1) * 0.9, p))
464 |     return torch.log(p + 1e-10) - torch.log(1 - p + 1e-10)
465 |
466 |
467 | def _norm_except_dim(w, norm_type, dim):
468 |     if norm_type == 1 or norm_type == 2:
469 |         return torch.norm_except_dim(w, norm_type, dim)
470 |     elif norm_type == float('inf'):
471 |         return _max_except_dim(w, dim)
472 |
473 |
474 | def _max_except_dim(input, dim):
475 |     maxed = input
476 |     for axis in range(input.ndimension() - 1, dim, -1):
477 |         maxed, _ = maxed.max(axis, keepdim=True)
478 |     for axis in range(dim - 1, -1, -1):
479 |         maxed, _ = maxed.max(axis, keepdim=True)
480 |     return maxed
481 |
482 |
483 | def operator_norm_settings(domain, codomain):
484 |     if domain == 1 and codomain == 1:
485 |         # maximum l1-norm of column
486 |         max_across_input_dims = True
487 |         norm_type = 1
488 |     elif domain == 1 and codomain == 2:
489 |         # maximum l2-norm of column
490 |         max_across_input_dims = True
491 |         norm_type = 2
492 |     elif domain == 1 and codomain == float("inf"):
493 |         # maximum l-inf norm of column
494 |         max_across_input_dims = True
495 |         norm_type = float("inf")
496 |     elif domain == 2 and codomain == float("inf"):
497 |         # maximum l2-norm of row
498 |         max_across_input_dims = False
499 |         norm_type = 2
500 |     elif domain == float("inf") and codomain == float("inf"):
501 |         # maximum l1-norm of row
502 |         max_across_input_dims = False
503 |         norm_type = 1
504 |     else:
505 |         raise ValueError('Unknown combination of domain "{}" and codomain "{}"'.format(domain, codomain))
506 |
507 |     return max_across_input_dims, norm_type
508 |
509 |
510 | def get_linear(in_features, out_features, bias=True, coeff=0.97, domain=None, codomain=None, **kwargs):
511 |     _linear = InducedNormLinear
512 |     if domain == 1:
513 |         if codomain in [1, 2, float('inf')]:
514 |             _linear = LopLinear
515 |     elif codomain == float('inf'):
516 |         if domain in [2, float('inf')]:
517 |             _linear = LopLinear
518 |     return _linear(in_features, out_features, bias, coeff, domain, codomain, **kwargs)
519 |
520 |
521 | def get_conv2d(
522 |     in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, domain=None, codomain=None, **kwargs
523 | ):
524 |     _conv2d = InducedNormConv2d
525 |     if domain == 1:
526 |         if codomain in [1, 2, float('inf')]:
527 |             _conv2d = LopConv2d
528 |     elif codomain == float('inf'):
529 |         if domain in [2, float('inf')]:
530 |             _conv2d = LopConv2d
531 |     return _conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, coeff, domain, codomain, **kwargs)
532 |
--------------------------------------------------------------------------------
/resflows/layers/base/mixed_lipschitz.py:
--------------------------------------------------------------------------------
1 | import collections.abc as container_abcs
2 | from itertools import repeat
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.init as init
7 | import torch.nn.functional as F
8 |
9 | __all__ = ['InducedNormLinear', 'InducedNormConv2d']
10 |
11 |
12 | class InducedNormLinear(nn.Module):
13 |
14 |     def __init__(
15 |         self, in_features, out_features, bias=True, coeff=0.97, domain=2, codomain=2, n_iterations=None, atol=None,
16 |         rtol=None, zero_init=False, **unused_kwargs
17 |     ):
18 |         del unused_kwargs
19 |         super(InducedNormLinear, self).__init__()
20 |         self.in_features = in_features
21 |         self.out_features = out_features
22 |         self.coeff = coeff
23 |         self.n_iterations = n_iterations
24 |         self.atol = atol
25 |         self.rtol = rtol
26 |         self.domain = domain
27 |         self.codomain = codomain
28 |         self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
29 |         if bias:
30 |             self.bias = nn.Parameter(torch.Tensor(out_features))
31 |         else:
32 |             self.register_parameter('bias', None)
33 |         self.reset_parameters(zero_init)
34 |
35 |         with torch.no_grad():
36 |             domain, codomain = self.compute_domain_codomain()
37 |
38 |         h, w = self.weight.shape
39 |         self.register_buffer('scale', torch.tensor(0.))
40 |         self.register_buffer('u', normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain))
41 |         self.register_buffer('v', normalize_v(self.weight.new_empty(w).normal_(0, 1), domain))
42 |
43 |         # Try different random seeds to find the best u and v.
44 |         with torch.no_grad():
45 |             self.compute_weight(True, n_iterations=200, atol=None, rtol=None)
46 |             best_scale = self.scale.clone()
47 |             best_u, best_v = self.u.clone(), self.v.clone()
48 |             if not (domain == 2 and codomain == 2):
49 |                 for _ in range(10):
50 |                     self.register_buffer('u', normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain))
51 |                     self.register_buffer('v', normalize_v(self.weight.new_empty(w).normal_(0, 1), domain))
52 |                     self.compute_weight(True, n_iterations=200)
53 |                     if self.scale > best_scale:
54 |                         best_u, best_v, best_scale = self.u.clone(), self.v.clone(), self.scale.clone()
55 |             self.u.copy_(best_u)
56 |             self.v.copy_(best_v)
57 |
58 |     def reset_parameters(self, zero_init=False):
59 |         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
60 |         if zero_init:
61 |             # normalize cannot handle zero weight in some cases.
62 |             self.weight.data.div_(1000)
63 |         if self.bias is not None:
64 |             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
65 |             bound = 1 / math.sqrt(fan_in)
66 |             init.uniform_(self.bias, -bound, bound)
67 |
68 |     def compute_domain_codomain(self):
69 |         if torch.is_tensor(self.domain):
70 |             domain = asym_squash(self.domain)
71 |             codomain = asym_squash(self.codomain)
72 |         else:
73 |             domain, codomain = self.domain, self.codomain
74 |         return domain, codomain
75 |
76 |     def compute_one_iter(self):
77 |         domain, codomain = self.compute_domain_codomain()
78 |         u = self.u.detach()
79 |         v = self.v.detach()
80 |         weight = self.weight.detach()
81 |         u = normalize_u(torch.mv(weight, v), codomain)
82 |         v = normalize_v(torch.mv(weight.t(), u), domain)
83 |         return torch.dot(u, torch.mv(weight, v))
84 |
85 |     def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
86 |         u = self.u
87 |         v = self.v
88 |         weight = self.weight
89 |
90 |         if update:
91 |
92 |             n_iterations = self.n_iterations if n_iterations is None else n_iterations
93 |             atol = self.atol if atol is None else atol
94 |             rtol = self.rtol if rtol is None else rtol
95 |
96 |             if n_iterations is None and (atol is None or rtol is None):
97 |                 raise ValueError('Need one of n_iterations or (atol, rtol).')
98 |
99 |             max_itrs = 200
100 |             if n_iterations is not None:
101 |                 max_itrs = n_iterations
102 |
103 |             with torch.no_grad():
104 |                 domain, codomain = self.compute_domain_codomain()
105 |                 for _ in range(max_itrs):
106 |                     # Algorithm from http://www.qetlab.com/InducedMatrixNorm.
107 | if n_iterations is None and atol is not None and rtol is not None: 108 | old_v = v.clone() 109 | old_u = u.clone() 110 | 111 | u = normalize_u(torch.mv(weight, v), codomain, out=u) 112 | v = normalize_v(torch.mv(weight.t(), u), domain, out=v) 113 | 114 | if n_iterations is None and atol is not None and rtol is not None: 115 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5) 116 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5) 117 | tol_u = atol + rtol * torch.max(u) 118 | tol_v = atol + rtol * torch.max(v) 119 | if err_u < tol_u and err_v < tol_v: 120 | break 121 | self.v.copy_(v) 122 | self.u.copy_(u) 123 | u = u.clone() 124 | v = v.clone() 125 | 126 | sigma = torch.dot(u, torch.mv(weight, v)) 127 | with torch.no_grad(): 128 | self.scale.copy_(sigma) 129 | # soft normalization: only when sigma larger than coeff 130 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff) 131 | weight = weight / factor 132 | return weight 133 | 134 | def forward(self, input): 135 | weight = self.compute_weight(update=False) 136 | return F.linear(input, weight, self.bias) 137 | 138 | def extra_repr(self): 139 | domain, codomain = self.compute_domain_codomain() 140 | return ( 141 | 'in_features={}, out_features={}, bias={}' 142 | ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format( 143 | self.in_features, self.out_features, self.bias is not None, self.coeff, domain, codomain, 144 | self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain) 145 | ) 146 | ) 147 | 148 | 149 | class InducedNormConv2d(nn.Module): 150 | 151 | def __init__( 152 | self, in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, domain=2, codomain=2, 153 | n_iterations=None, atol=None, rtol=None, **unused_kwargs 154 | ): 155 | del unused_kwargs 156 | super(InducedNormConv2d, self).__init__() 157 | self.in_channels = in_channels 158 | self.out_channels = out_channels 159 | self.kernel_size = 
_pair(kernel_size) 160 | self.stride = _pair(stride) 161 | self.padding = _pair(padding) 162 | self.coeff = coeff 163 | self.n_iterations = n_iterations 164 | self.domain = domain 165 | self.codomain = codomain 166 | self.atol = atol 167 | self.rtol = rtol 168 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) 169 | if bias: 170 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 171 | else: 172 | self.register_parameter('bias', None) 173 | self.reset_parameters() 174 | self.register_buffer('initialized', torch.tensor(0)) 175 | self.register_buffer('spatial_dims', torch.tensor([1., 1.])) 176 | self.register_buffer('scale', torch.tensor(0.)) 177 | self.register_buffer('u', self.weight.new_empty(self.out_channels)) 178 | self.register_buffer('v', self.weight.new_empty(self.in_channels)) 179 | 180 | def compute_domain_codomain(self): 181 | if torch.is_tensor(self.domain): 182 | domain = asym_squash(self.domain) 183 | codomain = asym_squash(self.codomain) 184 | else: 185 | domain, codomain = self.domain, self.codomain 186 | return domain, codomain 187 | 188 | def reset_parameters(self): 189 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 190 | if self.bias is not None: 191 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 192 | bound = 1 / math.sqrt(fan_in) 193 | init.uniform_(self.bias, -bound, bound) 194 | 195 | def _initialize_u_v(self): 196 | with torch.no_grad(): 197 | domain, codomain = self.compute_domain_codomain() 198 | if self.kernel_size == (1, 1): 199 | self.u.resize_(self.out_channels).normal_(0, 1) 200 | self.u.copy_(normalize_u(self.u, codomain)) 201 | self.v.resize_(self.in_channels).normal_(0, 1) 202 | self.v.copy_(normalize_v(self.v, domain)) 203 | else: 204 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item()) 205 | with torch.no_grad(): 206 | num_input_dim = c * h * w 207 | self.v.resize_(num_input_dim).normal_(0, 1) 208 | 
self.v.copy_(normalize_v(self.v, domain)) 209 | # forward call to infer the shape 210 | u = F.conv2d( 211 | self.v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None 212 | ) 213 | num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3] 214 | # overwrite u with random init 215 | self.u.resize_(num_output_dim).normal_(0, 1) 216 | self.u.copy_(normalize_u(self.u, codomain)) 217 | 218 | self.initialized.fill_(1) 219 | 220 | # Try different random seeds to find the best u and v. 221 | self.compute_weight(True) 222 | best_scale = self.scale.clone() 223 | best_u, best_v = self.u.clone(), self.v.clone() 224 | if not (domain == 2 and codomain == 2): 225 | for _ in range(10): 226 | if self.kernel_size == (1, 1): 227 | self.u.copy_(normalize_u(self.weight.new_empty(self.out_channels).normal_(0, 1), codomain)) 228 | self.v.copy_(normalize_v(self.weight.new_empty(self.in_channels).normal_(0, 1), domain)) 229 | else: 230 | self.u.copy_(normalize_u(torch.randn(num_output_dim).to(self.weight), codomain)) 231 | self.v.copy_(normalize_v(torch.randn(num_input_dim).to(self.weight), domain)) 232 | self.compute_weight(True, n_iterations=200) 233 | if self.scale > best_scale: 234 | best_u, best_v = self.u.clone(), self.v.clone() 235 | self.u.copy_(best_u) 236 | self.v.copy_(best_v) 237 | 238 | def compute_one_iter(self): 239 | if not self.initialized: 240 | raise ValueError('Layer needs to be initialized first.') 241 | domain, codomain = self.compute_domain_codomain() 242 | if self.kernel_size == (1, 1): 243 | u = self.u.detach() 244 | v = self.v.detach() 245 | weight = self.weight.detach().view(self.out_channels, self.in_channels) 246 | u = normalize_u(torch.mv(weight, v), codomain) 247 | v = normalize_v(torch.mv(weight.t(), u), domain) 248 | return torch.dot(u, torch.mv(weight, v)) 249 | else: 250 | u = self.u.detach() 251 | v = self.v.detach() 252 | weight = self.weight.detach() 253 | c, h, w = self.in_channels, 
int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item()) 254 | u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 255 | out_shape = u_s.shape 256 | u = normalize_u(u_s.view(-1), codomain) 257 | v_s = F.conv_transpose2d( 258 | u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0 259 | ) 260 | v = normalize_v(v_s.view(-1), domain) 261 | weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 262 | return torch.dot(u.view(-1), weight_v.view(-1)) 263 | 264 | def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None): 265 | if not self.initialized: 266 | self._initialize_u_v() 267 | 268 | if self.kernel_size == (1, 1): 269 | return self._compute_weight_1x1(update, n_iterations, atol, rtol) 270 | else: 271 | return self._compute_weight_kxk(update, n_iterations, atol, rtol) 272 | 273 | def _compute_weight_1x1(self, update=True, n_iterations=None, atol=None, rtol=None): 274 | n_iterations = self.n_iterations if n_iterations is None else n_iterations 275 | atol = self.atol if atol is None else atol 276 | rtol = self.rtol if rtol is None else rtol 277 | 278 | if n_iterations is None and (atol is None or rtol is None): 279 | raise ValueError('Need one of n_iterations or (atol, rtol).') 280 | 281 | max_itrs = 200 282 | if n_iterations is not None: 283 | max_itrs = n_iterations 284 | 285 | u = self.u 286 | v = self.v 287 | weight = self.weight.view(self.out_channels, self.in_channels) 288 | if update: 289 | with torch.no_grad(): 290 | domain, codomain = self.compute_domain_codomain() 291 | itrs_used = 0 292 | for _ in range(max_itrs): 293 | old_v = v.clone() 294 | old_u = u.clone() 295 | 296 | u = normalize_u(torch.mv(weight, v), codomain, out=u) 297 | v = normalize_v(torch.mv(weight.t(), u), domain, out=v) 298 | 299 | itrs_used = itrs_used + 1 300 | 301 | if n_iterations is None and atol is not None and rtol is not None: 302 |
err_u = torch.norm(u - old_u) / (u.nelement()**0.5) 303 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5) 304 | tol_u = atol + rtol * torch.max(u) 305 | tol_v = atol + rtol * torch.max(v) 306 | if err_u < tol_u and err_v < tol_v: 307 | break 308 | if itrs_used > 0: 309 | if domain != 1 and domain != 2: 310 | self.v.copy_(v) 311 | if codomain != 2 and codomain != float('inf'): 312 | self.u.copy_(u) 313 | u = u.clone() 314 | v = v.clone() 315 | 316 | sigma = torch.dot(u, torch.mv(weight, v)) 317 | with torch.no_grad(): 318 | self.scale.copy_(sigma) 319 | # soft normalization: only when sigma larger than coeff 320 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff) 321 | weight = weight / factor 322 | return weight.view(self.out_channels, self.in_channels, 1, 1) 323 | 324 | def _compute_weight_kxk(self, update=True, n_iterations=None, atol=None, rtol=None): 325 | n_iterations = self.n_iterations if n_iterations is None else n_iterations 326 | atol = self.atol if atol is None else atol 327 | rtol = self.rtol if rtol is None else rtol 328 | 329 | if n_iterations is None and (atol is None or rtol is None): 330 | raise ValueError('Need one of n_iterations or (atol, rtol).') 331 | 332 | max_itrs = 200 333 | if n_iterations is not None: 334 | max_itrs = n_iterations 335 | 336 | u = self.u 337 | v = self.v 338 | weight = self.weight 339 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item()) 340 | if update: 341 | with torch.no_grad(): 342 | domain, codomain = self.compute_domain_codomain() 343 | itrs_used = 0 344 | for _ in range(max_itrs): 345 | old_u = u.clone() 346 | old_v = v.clone() 347 | 348 | u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 349 | out_shape = u_s.shape 350 | u = normalize_u(u_s.view(-1), codomain, out=u) 351 | 352 | v_s = F.conv_transpose2d( 353 | u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0 354 | )
355 | v = normalize_v(v_s.view(-1), domain, out=v) 356 | 357 | itrs_used = itrs_used + 1 358 | if n_iterations is None and atol is not None and rtol is not None: 359 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5) 360 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5) 361 | tol_u = atol + rtol * torch.max(u) 362 | tol_v = atol + rtol * torch.max(v) 363 | if err_u < tol_u and err_v < tol_v: 364 | break 365 | if itrs_used > 0: 366 | if domain != 2: 367 | self.v.copy_(v) 368 | if codomain != 2: 369 | self.u.copy_(u) 370 | v = v.clone() 371 | u = u.clone() 372 | 373 | weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None) 374 | weight_v = weight_v.view(-1) 375 | sigma = torch.dot(u.view(-1), weight_v) 376 | with torch.no_grad(): 377 | self.scale.copy_(sigma) 378 | # soft normalization: only when sigma larger than coeff 379 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff) 380 | weight = weight / factor 381 | return weight 382 | 383 | def forward(self, input): 384 | if not self.initialized: self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims)) 385 | weight = self.compute_weight(update=False) 386 | return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1) 387 | 388 | def extra_repr(self): 389 | domain, codomain = self.compute_domain_codomain() 390 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}') 391 | if self.padding != (0,) * len(self.padding): 392 | s += ', padding={padding}' 393 | if self.bias is None: 394 | s += ', bias=False' 395 | s += ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format( 396 | self.coeff, domain, codomain, self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain) 397 | ) 398 | return s.format(**self.__dict__) 399 | 400 | 401 | def projmax_(v): 402 | """Inplace argmax on absolute value.""" 403 | ind = torch.argmax(torch.abs(v)) 404 | 
v.zero_() 405 | v[ind] = 1 406 | return v 407 | 408 | 409 | def normalize_v(v, domain, out=None): 410 | if not torch.is_tensor(domain) and domain == 2: 411 | v = F.normalize(v, p=2, dim=0, out=out) 412 | elif domain == 1: 413 | v = projmax_(v) 414 | else: 415 | vabs = torch.abs(v) 416 | vph = v / vabs 417 | vph[torch.isnan(vph)] = 1 418 | vabs = vabs / torch.max(vabs) 419 | vabs = vabs**(1 / (domain - 1)) 420 | v = vph * vabs / vector_norm(vabs, domain) 421 | return v 422 | 423 | 424 | def normalize_u(u, codomain, out=None): 425 | if not torch.is_tensor(codomain) and codomain == 2: 426 | u = F.normalize(u, p=2, dim=0, out=out) 427 | elif codomain == float('inf'): 428 | u = projmax_(u) 429 | else: 430 | uabs = torch.abs(u) 431 | uph = u / uabs 432 | uph[torch.isnan(uph)] = 1 433 | uabs = uabs / torch.max(uabs) 434 | uabs = uabs**(codomain - 1) 435 | if codomain == 1: 436 | u = uph * uabs / vector_norm(uabs, float('inf')) 437 | else: 438 | u = uph * uabs / vector_norm(uabs, codomain / (codomain - 1)) 439 | return u 440 | 441 | 442 | def vector_norm(x, p): 443 | x = x.view(-1) 444 | return torch.sum(x**p)**(1 / p) 445 | 446 | 447 | def leaky_elu(x, a=0.3): 448 | return a * x + (1 - a) * F.elu(x) 449 | 450 | 451 | def asym_squash(x): 452 | return torch.tanh(-leaky_elu(-x + 0.5493061829986572)) * 2 + 3 453 | 454 | 455 | # def asym_squash(x): 456 | # return torch.tanh(x) / 2. + 2. 
457 | 458 | 459 | def _ntuple(n): 460 | 461 | def parse(x): 462 | if isinstance(x, container_abcs.Iterable): 463 | return x 464 | return tuple(repeat(x, n)) 465 | 466 | return parse 467 | 468 | 469 | _single = _ntuple(1) 470 | _pair = _ntuple(2) 471 | _triple = _ntuple(3) 472 | _quadruple = _ntuple(4) 473 | 474 | if __name__ == '__main__': 475 | 476 | p = nn.Parameter(torch.tensor(2.1)) 477 | 478 | m = InducedNormConv2d(10, 2, 3, 1, 1, atol=1e-3, rtol=1e-3, domain=p, codomain=p) 479 | W = m.compute_weight() 480 | 481 | m.compute_one_iter().backward() 482 | print(p.grad) 483 | 484 | # m.weight.data.copy_(W) 485 | # W = m.compute_weight().cpu().detach().numpy() 486 | # import numpy as np 487 | # print( 488 | # '{} {} {}'.format( 489 | # np.linalg.norm(W, ord=2, axis=(0, 1)), 490 | # '>' if np.linalg.norm(W, ord=2, axis=(0, 1)) > m.scale else '<', 491 | # m.scale, 492 | # ) 493 | # ) 494 | -------------------------------------------------------------------------------- /resflows/layers/base/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc as container_abcs 2 | from itertools import repeat 3 | 4 | 5 | def _ntuple(n): 6 | 7 | def parse(x): 8 | if isinstance(x, container_abcs.Iterable): 9 | return x 10 | return tuple(repeat(x, n)) 11 | 12 | return parse 13 | 14 | 15 | _single = _ntuple(1) 16 | _pair = _ntuple(2) 17 | _triple = _ntuple(3) 18 | _quadruple = _ntuple(4) 19 | -------------------------------------------------------------------------------- /resflows/layers/container.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SequentialFlow(nn.Module): 5 | """A generalized nn.Sequential container for normalizing flows. 
6 | """ 7 | 8 | def __init__(self, layersList): 9 | super(SequentialFlow, self).__init__() 10 | self.chain = nn.ModuleList(layersList) 11 | 12 | def forward(self, x, logpx=None): 13 | if logpx is None: 14 | for i in range(len(self.chain)): 15 | x = self.chain[i](x) 16 | return x 17 | else: 18 | for i in range(len(self.chain)): 19 | x, logpx = self.chain[i](x, logpx) 20 | return x, logpx 21 | 22 | def inverse(self, y, logpy=None): 23 | if logpy is None: 24 | for i in range(len(self.chain) - 1, -1, -1): 25 | y = self.chain[i].inverse(y) 26 | return y 27 | else: 28 | for i in range(len(self.chain) - 1, -1, -1): 29 | y, logpy = self.chain[i].inverse(y, logpy) 30 | return y, logpy 31 | 32 | 33 | class Inverse(nn.Module): 34 | 35 | def __init__(self, flow): 36 | super(Inverse, self).__init__() 37 | self.flow = flow 38 | 39 | def forward(self, x, logpx=None): 40 | return self.flow.inverse(x, logpx) 41 | 42 | def inverse(self, y, logpy=None): 43 | return self.flow.forward(y, logpy) 44 | -------------------------------------------------------------------------------- /resflows/layers/coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import mask_utils 4 | 5 | __all__ = ['CouplingBlock', 'ChannelCouplingBlock', 'MaskedCouplingBlock'] 6 | 7 | 8 | class CouplingBlock(nn.Module): 9 | """Basic coupling layer for Tensors of shape (n,d). 
10 | 11 | Forward computation: 12 | y_a = x_a 13 | y_b = x_b * sigmoid(s(x_a) + 2) + t(x_a) 14 | Inverse computation: 15 | x_a = y_a 16 | x_b = (y_b - t(y_a)) / sigmoid(s(y_a) + 2) 17 | """ 18 | 19 | def __init__(self, dim, nnet, swap=False): 20 | """ 21 | Args: 22 | dim (int): number of input features; must be even. nnet (nn.Module): maps one half of the input to (s, t), concatenated along axis 1. 23 | swap (bool): if True, condition on the second half instead of the first. 24 | """ 25 | super(CouplingBlock, self).__init__() 26 | assert (dim % 2 == 0) 27 | self.d = dim // 2 28 | self.nnet = nnet 29 | self.swap = swap 30 | 31 | def func_s_t(self, x): 32 | f = self.nnet(x) 33 | s = f[:, :self.d] 34 | t = f[:, self.d:] 35 | return s, t 36 | 37 | def forward(self, x, logpx=None): 38 | """Forward computation of a simple coupling split along axis 1. 39 | """ 40 | x_a = x[:, :self.d] if not self.swap else x[:, self.d:] 41 | x_b = x[:, self.d:] if not self.swap else x[:, :self.d] 42 | y_a, y_b, logdetgrad = self._forward_computation(x_a, x_b) 43 | y = [y_a, y_b] if not self.swap else [y_b, y_a] 44 | 45 | if logpx is None: 46 | return torch.cat(y, dim=1) 47 | else: 48 | return torch.cat(y, dim=1), logpx - logdetgrad.view(x.size(0), -1).sum(1, keepdim=True) 49 | 50 | def inverse(self, y, logpy=None): 51 | """Inverse computation of a simple coupling split along axis 1. 52 | """ 53 | y_a = y[:, :self.d] if not self.swap else y[:, self.d:] 54 | y_b = y[:, self.d:] if not self.swap else y[:, :self.d] 55 | x_a, x_b, logdetgrad = self._inverse_computation(y_a, y_b) 56 | x = [x_a, x_b] if not self.swap else [x_b, x_a] 57 | if logpy is None: 58 | return torch.cat(x, dim=1) 59 | else: 60 | return torch.cat(x, dim=1), logpy + logdetgrad 61 | 62 | def _forward_computation(self, x_a, x_b): 63 | y_a = x_a 64 | s_a, t_a = self.func_s_t(x_a) 65 | scale = torch.sigmoid(s_a + 2.) 66 | y_b = x_b * scale + t_a 67 | logdetgrad = self._logdetgrad(scale) 68 | return y_a, y_b, logdetgrad 69 | 70 | def _inverse_computation(self, y_a, y_b): 71 | x_a = y_a 72 | s_a, t_a = self.func_s_t(y_a) 73 | scale = torch.sigmoid(s_a + 2.)
74 | x_b = (y_b - t_a) / scale 75 | logdetgrad = self._logdetgrad(scale) 76 | return x_a, x_b, logdetgrad 77 | 78 | def _logdetgrad(self, scale): 79 | """ 80 | Returns: 81 | Tensor (N, 1): containing ln |det J|, where J is the Jacobian of the transform 82 | """ 83 | return torch.log(scale).view(scale.shape[0], -1).sum(1, keepdim=True) 84 | 85 | def extra_repr(self): 86 | return 'dim={d}, swap={swap}'.format(**self.__dict__) 87 | 88 | 89 | class ChannelCouplingBlock(CouplingBlock): 90 | """Channel-wise coupling layer for images. 91 | """ 92 | 93 | def __init__(self, dim, nnet, mask_type='channel0'): 94 | if mask_type == 'channel0': 95 | swap = False 96 | elif mask_type == 'channel1': 97 | swap = True 98 | else: 99 | raise ValueError('Unknown mask type.') 100 | super(ChannelCouplingBlock, self).__init__(dim, nnet, swap) 101 | self.mask_type = mask_type 102 | 103 | def extra_repr(self): 104 | return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__) 105 | 106 | 107 | class MaskedCouplingBlock(nn.Module): 108 | """Coupling layer for images implemented using masks. 109 | """ 110 | 111 | def __init__(self, dim, nnet, mask_type='checkerboard0'): 112 | nn.Module.__init__(self) 113 | self.d = dim 114 | self.nnet = nnet 115 | self.mask_type = mask_type 116 | 117 | def func_s_t(self, x): 118 | f = self.nnet(x) 119 | s = torch.sigmoid(f[:, :self.d] + 2.)
120 | t = f[:, self.d:] 121 | return s, t 122 | 123 | def forward(self, x, logpx=None): 124 | # get mask 125 | b = mask_utils.get_mask(x, mask_type=self.mask_type) 126 | 127 | # masked forward 128 | x_a = b * x 129 | s, t = self.func_s_t(x_a) 130 | y = (x * s + t) * (1 - b) + x_a 131 | 132 | if logpx is None: 133 | return y 134 | else: 135 | return y, logpx - self._logdetgrad(s, b) 136 | 137 | def inverse(self, y, logpy=None): 138 | # get mask 139 | b = mask_utils.get_mask(y, mask_type=self.mask_type) 140 | 141 | # masked inverse 142 | y_a = b * y 143 | s, t = self.func_s_t(y_a) 144 | x = y_a + (1 - b) * (y - t) / s 145 | 146 | if logpy is None: 147 | return x 148 | else: 149 | return x, logpy + self._logdetgrad(s, b) 150 | 151 | def _logdetgrad(self, s, mask): 152 | return torch.log(s).mul_(1 - mask).view(s.shape[0], -1).sum(1, keepdim=True) 153 | 154 | def extra_repr(self): 155 | return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__) 156 | -------------------------------------------------------------------------------- /resflows/layers/elemwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | _DEFAULT_ALPHA = 1e-6 6 | 7 | 8 | class ZeroMeanTransform(nn.Module): 9 | 10 | def __init__(self): 11 | nn.Module.__init__(self) 12 | 13 | def forward(self, x, logpx=None): 14 | x = x - .5 15 | if logpx is None: 16 | return x 17 | return x, logpx 18 | 19 | def inverse(self, y, logpy=None): 20 | y = y + .5 21 | if logpy is None: 22 | return y 23 | return y, logpy 24 | 25 | 26 | class Normalize(nn.Module): 27 | 28 | def __init__(self, mean, std): 29 | nn.Module.__init__(self) 30 | self.register_buffer('mean', torch.as_tensor(mean, dtype=torch.float32)) 31 | self.register_buffer('std', torch.as_tensor(std, dtype=torch.float32)) 32 | 33 | def forward(self, x, logpx=None): 34 | y = x.clone() 35 | c = len(self.mean) 36 | y[:, :c].sub_(self.mean[None, :, None,
None]).div_(self.std[None, :, None, None]) 37 | if logpx is None: 38 | return y 39 | else: 40 | return y, logpx - self._logdetgrad(x) 41 | 42 | def inverse(self, y, logpy=None): 43 | x = y.clone() 44 | c = len(self.mean) 45 | x[:, :c].mul_(self.std[None, :, None, None]).add_(self.mean[None, :, None, None]) 46 | if logpy is None: 47 | return x 48 | else: 49 | return x, logpy + self._logdetgrad(x) 50 | 51 | def _logdetgrad(self, x): 52 | logdetgrad = ( 53 | self.std.abs().log().mul_(-1).view(1, -1, 1, 1).expand(x.shape[0], len(self.std), x.shape[2], x.shape[3]) 54 | ) 55 | return logdetgrad.reshape(x.shape[0], -1).sum(-1, keepdim=True) 56 | 57 | 58 | class LogitTransform(nn.Module): 59 | """ 60 | The preprocessing step used in Real NVP: 61 | forward: y = logit(a + (1 - 2a) * x) 62 | inverse: x = (sigmoid(y) - a) / (1 - 2a) 63 | """ 64 | 65 | def __init__(self, alpha=_DEFAULT_ALPHA): 66 | nn.Module.__init__(self) 67 | self.alpha = alpha 68 | 69 | def forward(self, x, logpx=None): 70 | s = self.alpha + (1 - 2 * self.alpha) * x 71 | y = torch.log(s) - torch.log(1 - s) 72 | if logpx is None: 73 | return y 74 | return y, logpx - self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True) 75 | 76 | def inverse(self, y, logpy=None): 77 | x = (torch.sigmoid(y) - self.alpha) / (1 - 2 * self.alpha) 78 | if logpy is None: 79 | return x 80 | return x, logpy + self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True) 81 | 82 | def _logdetgrad(self, x): 83 | s = self.alpha + (1 - 2 * self.alpha) * x 84 | logdetgrad = -torch.log(s - s * s) + math.log(1 - 2 * self.alpha) 85 | return logdetgrad 86 | 87 | def __repr__(self): 88 | return ('{name}({alpha})'.format(name=self.__class__.__name__, **self.__dict__)) 89 | -------------------------------------------------------------------------------- /resflows/layers/glow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class
InvertibleLinear(nn.Module): 7 | 8 | def __init__(self, dim): 9 | super(InvertibleLinear, self).__init__() 10 | self.dim = dim 11 | self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)]) 12 | 13 | def forward(self, x, logpx=None): 14 | y = F.linear(x, self.weight) 15 | if logpx is None: 16 | return y 17 | else: 18 | return y, logpx - self._logdetgrad 19 | 20 | def inverse(self, y, logpy=None): 21 | x = F.linear(y, self.weight.inverse()) 22 | if logpy is None: 23 | return x 24 | else: 25 | return x, logpy + self._logdetgrad 26 | 27 | @property 28 | def _logdetgrad(self): 29 | return torch.log(torch.abs(torch.det(self.weight))) 30 | 31 | def extra_repr(self): 32 | return 'dim={}'.format(self.dim) 33 | 34 | 35 | class InvertibleConv2d(nn.Module): 36 | 37 | def __init__(self, dim): 38 | super(InvertibleConv2d, self).__init__() 39 | self.dim = dim 40 | self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)]) 41 | 42 | def forward(self, x, logpx=None): 43 | y = F.conv2d(x, self.weight.view(self.dim, self.dim, 1, 1)) 44 | if logpx is None: 45 | return y 46 | else: 47 | return y, logpx - self._logdetgrad.expand_as(logpx) * x.shape[2] * x.shape[3] 48 | 49 | def inverse(self, y, logpy=None): 50 | x = F.conv2d(y, self.weight.inverse().view(self.dim, self.dim, 1, 1)) 51 | if logpy is None: 52 | return x 53 | else: 54 | return x, logpy + self._logdetgrad.expand_as(logpy) * x.shape[2] * x.shape[3] 55 | 56 | @property 57 | def _logdetgrad(self): 58 | return torch.log(torch.abs(torch.det(self.weight))) 59 | 60 | def extra_repr(self): 61 | return 'dim={}'.format(self.dim) 62 | -------------------------------------------------------------------------------- /resflows/layers/iresblock.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | import logging 7 | 8 | logger = logging.getLogger() 9 | 10 | __all__ = ['iResBlock'] 11 | 12 | 13 | class 
iResBlock(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | nnet, 18 | geom_p=0.5, 19 | lamb=2., 20 | n_power_series=None, 21 | exact_trace=False, 22 | brute_force=False, 23 | n_samples=1, 24 | n_exact_terms=2, 25 | n_dist='geometric', 26 | neumann_grad=True, 27 | grad_in_forward=False, 28 | ): 29 | """ 30 | Args: 31 | nnet: an nn.Module computing the residual branch g(x). 32 | n_power_series: number of terms in the power series. If not None, uses a biased (truncated) approximation to logdet. 33 | exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian. 34 | brute_force: Computes the exact logdet. Only available for 2D inputs. 35 | """ 36 | nn.Module.__init__(self) 37 | self.nnet = nnet 38 | self.n_dist = n_dist 39 | self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1. - geom_p))) 40 | self.lamb = nn.Parameter(torch.tensor(lamb)) 41 | self.n_samples = n_samples 42 | self.n_power_series = n_power_series 43 | self.exact_trace = exact_trace 44 | self.brute_force = brute_force 45 | self.n_exact_terms = n_exact_terms 46 | self.grad_in_forward = grad_in_forward 47 | self.neumann_grad = neumann_grad 48 | 49 | # store the samples of n.
50 | self.register_buffer('last_n_samples', torch.zeros(self.n_samples)) 51 | self.register_buffer('last_firmom', torch.zeros(1)) 52 | self.register_buffer('last_secmom', torch.zeros(1)) 53 | 54 | def forward(self, x, logpx=None): 55 | if logpx is None: 56 | y = x + self.nnet(x) 57 | return y 58 | else: 59 | g, logdetgrad = self._logdetgrad(x) 60 | return x + g, logpx - logdetgrad 61 | 62 | def inverse(self, y, logpy=None): 63 | x = self._inverse_fixed_point(y) 64 | if logpy is None: 65 | return x 66 | else: 67 | return x, logpy + self._logdetgrad(x)[1] 68 | 69 | def _inverse_fixed_point(self, y, atol=1e-5, rtol=1e-5): 70 | x, x_prev = y - self.nnet(y), y 71 | i = 0 72 | tol = atol + y.abs() * rtol 73 | while not torch.all((x - x_prev)**2 / tol < 1): 74 | x, x_prev = y - self.nnet(x), x 75 | i += 1 76 | if i > 1000: 77 | logger.info('Iterations exceeded 1000 for inverse.') 78 | break 79 | return x 80 | 81 | def _logdetgrad(self, x): 82 | """Returns g(x) and logdet|d(x+g(x))/dx|.""" 83 | 84 | with torch.enable_grad(): 85 | if (self.brute_force or not self.training) and (x.ndimension() == 2 and x.shape[1] == 2): 86 | ########################################### 87 | # Brute-force compute Jacobian determinant. 88 | ########################################### 89 | x = x.requires_grad_(True) 90 | g = self.nnet(x) 91 | # Brute-force logdet only available for 2D. 
92 | jac = batch_jacobian(g, x) 93 | batch_dets = (jac[:, 0, 0] + 1) * (jac[:, 1, 1] + 1) - jac[:, 0, 1] * jac[:, 1, 0] 94 | return g, torch.log(torch.abs(batch_dets)).view(-1, 1) 95 | 96 | if self.n_dist == 'geometric': 97 | geom_p = torch.sigmoid(self.geom_p).item() 98 | sample_fn = lambda m: geometric_sample(geom_p, m) 99 | rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset) 100 | elif self.n_dist == 'poisson': 101 | lamb = self.lamb.item() 102 | sample_fn = lambda m: poisson_sample(lamb, m) 103 | rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset) 104 | 105 | if self.training: 106 | if self.n_power_series is None: 107 | # Unbiased estimation. 108 | lamb = self.lamb.item() 109 | n_samples = sample_fn(self.n_samples) 110 | n_power_series = max(n_samples) + self.n_exact_terms 111 | coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \ 112 | sum(n_samples >= k - self.n_exact_terms) / len(n_samples) 113 | else: 114 | # Truncated estimation. 115 | n_power_series = self.n_power_series 116 | coeff_fn = lambda k: 1. 117 | else: 118 | # Unbiased estimation with more exact terms. 119 | lamb = self.lamb.item() 120 | n_samples = sample_fn(self.n_samples) 121 | n_power_series = max(n_samples) + 20 122 | coeff_fn = lambda k: 1 / rcdf_fn(k, 20) * \ 123 | sum(n_samples >= k - 20) / len(n_samples) 124 | 125 | if not self.exact_trace: 126 | #################################### 127 | # Power series with trace estimator. 128 | #################################### 129 | vareps = torch.randn_like(x) 130 | 131 | # Choose the type of estimator. 132 | if self.training and self.neumann_grad: 133 | estimator_fn = neumann_logdet_estimator 134 | else: 135 | estimator_fn = basic_logdet_estimator 136 | 137 | # Do backprop-in-forward to save memory. 
138 | if self.training and self.grad_in_forward: 139 | g, logdetgrad = mem_eff_wrapper( 140 | estimator_fn, self.nnet, x, n_power_series, vareps, coeff_fn, self.training 141 | ) 142 | else: 143 | x = x.requires_grad_(True) 144 | g = self.nnet(x) 145 | logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, self.training) 146 | else: 147 | ############################################ 148 | # Power series with exact trace computation. 149 | ############################################ 150 | x = x.requires_grad_(True) 151 | g = self.nnet(x) 152 | jac = batch_jacobian(g, x) 153 | logdetgrad = batch_trace(jac) 154 | jac_k = jac 155 | for k in range(2, n_power_series + 1): 156 | jac_k = torch.bmm(jac, jac_k) 157 | logdetgrad = logdetgrad + (-1)**(k + 1) / k * coeff_fn(k) * batch_trace(jac_k) 158 | 159 | if self.training and self.n_power_series is None: 160 | self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples)) 161 | estimator = logdetgrad.detach() 162 | self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom)) 163 | self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom)) 164 | return g, logdetgrad.view(-1, 1) 165 | 166 | def extra_repr(self): 167 | return 'dist={}, n_samples={}, n_power_series={}, neumann_grad={}, exact_trace={}, brute_force={}'.format( 168 | self.n_dist, self.n_samples, self.n_power_series, self.neumann_grad, self.exact_trace, self.brute_force 169 | ) 170 | 171 | 172 | def batch_jacobian(g, x): 173 | jac = [] 174 | for d in range(g.shape[1]): 175 | jac.append(torch.autograd.grad(torch.sum(g[:, d]), x, create_graph=True)[0].view(x.shape[0], 1, x.shape[1])) 176 | return torch.cat(jac, 1) 177 | 178 | 179 | def batch_trace(M): 180 | return M.view(M.shape[0], -1)[:, ::M.shape[1] + 1].sum(1) 181 | 182 | 183 | ##################### 184 | # Logdet Estimators 185 | ##################### 186 | class MemoryEfficientLogDetEstimator(torch.autograd.Function): 187 | 188 | @staticmethod 189 | def forward(ctx, 
estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params): 190 | ctx.training = training 191 | with torch.enable_grad(): 192 | x = x.detach().requires_grad_(True) 193 | g = gnet(x) 194 | ctx.g = g 195 | ctx.x = x 196 | logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training) 197 | 198 | if training: 199 | grad_x, *grad_params = torch.autograd.grad( 200 | logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True 201 | ) 202 | if grad_x is None: 203 | grad_x = torch.zeros_like(x) 204 | ctx.save_for_backward(grad_x, *g_params, *grad_params) 205 | 206 | return safe_detach(g), safe_detach(logdetgrad) 207 | 208 | @staticmethod 209 | def backward(ctx, grad_g, grad_logdetgrad): 210 | training = ctx.training 211 | if not training: 212 | raise ValueError('Provide training=True if using backward.') 213 | 214 | with torch.enable_grad(): 215 | grad_x, *params_and_grad = ctx.saved_tensors 216 | g, x = ctx.g, ctx.x 217 | 218 | # Precomputed gradients. 219 | g_params = params_and_grad[:len(params_and_grad) // 2] 220 | grad_params = params_and_grad[len(params_and_grad) // 2:] 221 | 222 | dg_x, *dg_params = torch.autograd.grad(g, [x] + g_params, grad_g, allow_unused=True) 223 | 224 | # Update based on gradient from logdetgrad. 225 | dL = grad_logdetgrad[0].detach() 226 | with torch.no_grad(): 227 | grad_x.mul_(dL) 228 | grad_params = tuple([g.mul_(dL) if g is not None else None for g in grad_params]) 229 | 230 | # Update based on gradient from g. 
231 | with torch.no_grad(): 232 | grad_x.add_(dg_x) 233 | grad_params = tuple([dg.add_(djac) if djac is not None else dg for dg, djac in zip(dg_params, grad_params)]) 234 | 235 | return (None, None, grad_x, None, None, None, None) + grad_params 236 | 237 | 238 | def basic_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training): 239 | vjp = vareps 240 | logdetgrad = torch.tensor(0.).to(x) 241 | for k in range(1, n_power_series + 1): 242 | vjp = torch.autograd.grad(g, x, vjp, create_graph=training, retain_graph=True)[0] 243 | tr = torch.sum(vjp.view(x.shape[0], -1) * vareps.view(x.shape[0], -1), 1) 244 | delta = (-1)**(k + 1) / k * coeff_fn(k) * tr 245 | logdetgrad = logdetgrad + delta 246 | return logdetgrad 247 | 248 | 249 | def neumann_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training): 250 | vjp = vareps 251 | neumann_vjp = vareps 252 | with torch.no_grad(): 253 | for k in range(1, n_power_series + 1): 254 | vjp = torch.autograd.grad(g, x, vjp, retain_graph=True)[0] 255 | neumann_vjp = neumann_vjp + (-1)**k * coeff_fn(k) * vjp 256 | vjp_jac = torch.autograd.grad(g, x, neumann_vjp, create_graph=training)[0] 257 | logdetgrad = torch.sum(vjp_jac.view(x.shape[0], -1) * vareps.view(x.shape[0], -1), 1) 258 | return logdetgrad 259 | 260 | 261 | def mem_eff_wrapper(estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training): 262 | 263 | # We need this in order to access the variables inside this module, 264 | # since we have no other way of getting variables along the execution path. 265 | if not isinstance(gnet, nn.Module): 266 | raise ValueError('g is required to be an instance of nn.Module.') 267 | 268 | return MemoryEfficientLogDetEstimator.apply( 269 | estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *list(gnet.parameters()) 270 | ) 271 | 272 | 273 | # -------- Helper distribution functions -------- 274 | # These take python ints or floats, not PyTorch tensors. 
275 | 276 | 277 | def geometric_sample(p, n_samples): 278 | return np.random.geometric(p, n_samples) 279 | 280 | 281 | def geometric_1mcdf(p, k, offset): 282 | if k <= offset: 283 | return 1. 284 | else: 285 | k = k - offset 286 | # P(N >= k) for N ~ Geometric(p) on {1, 2, ...} 287 | return (1 - p)**max(k - 1, 0) 288 | 289 | 290 | def poisson_sample(lamb, n_samples): 291 | return np.random.poisson(lamb, n_samples) 292 | 293 | 294 | def poisson_1mcdf(lamb, k, offset): 295 | if k <= offset: 296 | return 1. 297 | else: 298 | k = k - offset 299 | # P(N >= k) = 1 - P(N <= k - 1) for N ~ Poisson(lamb) 300 | total = 1. 301 | for i in range(1, k): 302 | total += lamb**i / math.factorial(i) 303 | return 1 - np.exp(-lamb) * total 304 | 305 | 306 | def sample_rademacher_like(y): 307 | return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1 308 | 309 | 310 | # -------------- Helper functions -------------- 311 | 312 | 313 | def safe_detach(tensor): 314 | return tensor.detach().requires_grad_(tensor.requires_grad) 315 | 316 | 317 | def _flatten(sequence): 318 | flat = [p.reshape(-1) for p in sequence] 319 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) 320 | 321 | 322 | def _flatten_convert_none_to_zeros(sequence, like_sequence): 323 | flat = [p.reshape(-1) if p is not None else torch.zeros_like(q).view(-1) for p, q in zip(sequence, like_sequence)] 324 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) 325 | -------------------------------------------------------------------------------- /resflows/layers/mask_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _get_checkerboard_mask(x, swap=False): 5 | n, c, h, w = x.size() 6 | 7 | H = ((h - 1) // 2 + 1) * 2 # H = h + 1 if h is odd and h if h is even 8 | W = ((w - 1) // 2 + 1) * 2 9 | 10 | # construct checkerboard mask 11 | if not swap: 12 | mask = torch.Tensor([[1, 0], [0, 1]]).repeat(H // 2, W // 2) 13 | else: 14 | mask = torch.Tensor([[0, 1], [1, 0]]).repeat(H // 2, W // 2) 15 | mask =
mask[:h, :w] 16 | mask = mask.contiguous().view(1, 1, h, w).expand(n, c, h, w).type_as(x.data) 17 | 18 | return mask 19 | 20 | 21 | def _get_channel_mask(x, swap=False): 22 | n, c, h, w = x.size() 23 | assert (c % 2 == 0) 24 | 25 | # construct channel-wise mask on the same device and dtype as x 26 | mask = torch.zeros_like(x) 27 | if not swap: 28 | mask[:, :c // 2] = 1 29 | else: 30 | mask[:, c // 2:] = 1 31 | return mask 32 | 33 | 34 | def get_mask(x, mask_type=None): 35 | if mask_type is None: 36 | return torch.zeros(x.size()).to(x) 37 | elif mask_type == 'channel0': 38 | return _get_channel_mask(x, swap=False) 39 | elif mask_type == 'channel1': 40 | return _get_channel_mask(x, swap=True) 41 | elif mask_type == 'checkerboard0': 42 | return _get_checkerboard_mask(x, swap=False) 43 | elif mask_type == 'checkerboard1': 44 | return _get_checkerboard_mask(x, swap=True) 45 | else: 46 | raise ValueError('Unknown mask type {}'.format(mask_type)) 47 | -------------------------------------------------------------------------------- /resflows/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | 5 | __all__ = ['MovingBatchNorm1d', 'MovingBatchNorm2d'] 6 | 7 | 8 | class MovingBatchNormNd(nn.Module): 9 | 10 | def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True): 11 | super(MovingBatchNormNd, self).__init__() 12 | self.num_features = num_features 13 | self.affine = affine 14 | self.eps = eps 15 | self.decay = decay 16 | self.bn_lag = bn_lag 17 | self.register_buffer('step', torch.zeros(1)) 18 | if self.affine: 19 | self.bias = Parameter(torch.Tensor(num_features)) 20 | else: 21 | self.register_parameter('bias', None) 22 | self.register_buffer('running_mean', torch.zeros(num_features)) 23 | self.reset_parameters() 24 | 25 | @property 26 | def shape(self): 27 | raise NotImplementedError 28 | 29 | def reset_parameters(self): 30 | self.running_mean.zero_()
31 | if self.affine: 32 | self.bias.data.zero_() 33 | 34 | def forward(self, x, logpx=None): 35 | c = x.size(1) 36 | used_mean = self.running_mean.clone().detach() 37 | 38 | if self.training: 39 | # compute batch statistics 40 | x_t = x.transpose(0, 1).contiguous().view(c, -1) 41 | batch_mean = torch.mean(x_t, dim=1) 42 | 43 | # moving average 44 | if self.bn_lag > 0: 45 | used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach()) 46 | used_mean /= (1. - self.bn_lag**(self.step[0] + 1)) 47 | 48 | # update running estimates 49 | self.running_mean -= self.decay * (self.running_mean - batch_mean.data) 50 | self.step += 1 51 | 52 | # perform normalization 53 | used_mean = used_mean.view(*self.shape).expand_as(x) 54 | 55 | y = x - used_mean 56 | 57 | if self.affine: 58 | bias = self.bias.view(*self.shape).expand_as(x) 59 | y = y + bias 60 | 61 | if logpx is None: 62 | return y 63 | else: 64 | return y, logpx 65 | 66 | def inverse(self, y, logpy=None): 67 | used_mean = self.running_mean 68 | 69 | if self.affine: 70 | bias = self.bias.view(*self.shape).expand_as(y) 71 | y = y - bias 72 | 73 | used_mean = used_mean.view(*self.shape).expand_as(y) 74 | x = y + used_mean 75 | 76 | if logpy is None: 77 | return x 78 | else: 79 | return x, logpy 80 | 81 | def __repr__(self): 82 | return ( 83 | '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},' 84 | ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__) 85 | ) 86 | 87 | 88 | class MovingBatchNorm1d(MovingBatchNormNd): 89 | 90 | @property 91 | def shape(self): 92 | return [1, -1] 93 | 94 | 95 | class MovingBatchNorm2d(MovingBatchNormNd): 96 | 97 | @property 98 | def shape(self): 99 | return [1, -1, 1, 1] 100 | -------------------------------------------------------------------------------- /resflows/layers/squeeze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['SqueezeLayer'] 
5 | 6 | 7 | class SqueezeLayer(nn.Module): 8 | 9 | def __init__(self, downscale_factor): 10 | super(SqueezeLayer, self).__init__() 11 | self.downscale_factor = downscale_factor 12 | 13 | def forward(self, x, logpx=None): 14 | squeeze_x = squeeze(x, self.downscale_factor) 15 | if logpx is None: 16 | return squeeze_x 17 | else: 18 | return squeeze_x, logpx 19 | 20 | def inverse(self, y, logpy=None): 21 | unsqueeze_y = unsqueeze(y, self.downscale_factor) 22 | if logpy is None: 23 | return unsqueeze_y 24 | else: 25 | return unsqueeze_y, logpy 26 | 27 | 28 | def unsqueeze(input, upscale_factor=2): 29 | return torch.pixel_shuffle(input, upscale_factor) 30 | 31 | 32 | def squeeze(input, downscale_factor=2): 33 | ''' 34 | [:, C, H*r, W*r] -> [:, C*r^2, H, W] 35 | ''' 36 | batch_size, in_channels, in_height, in_width = input.shape 37 | out_channels = in_channels * (downscale_factor**2) 38 | 39 | out_height = in_height // downscale_factor 40 | out_width = in_width // downscale_factor 41 | 42 | input_view = input.reshape(batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor) 43 | 44 | output = input_view.permute(0, 1, 3, 5, 2, 4) 45 | return output.reshape(batch_size, out_channels, out_height, out_width) 46 | -------------------------------------------------------------------------------- /resflows/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class CosineAnnealingWarmRestarts(_LRScheduler): 6 | r"""Set the learning rate of each parameter group using a cosine annealing 7 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` 8 | is the number of epochs since the last restart and :math:`T_{i}` is the number 9 | of epochs between two warm restarts in SGDR: 10 | .. 
math:: 11 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 12 | \cos(\frac{T_{cur}}{T_{i}}\pi)) 13 | When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 14 | When :math:`T_{cur}=0` (after restart), set :math:`\eta_t=\eta_{max}`. 15 | It has been proposed in 16 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. 17 | Args: 18 | optimizer (Optimizer): Wrapped optimizer. 19 | T_0 (int): Number of iterations for the first restart. 20 | T_mult (int, optional): A factor by which :math:`T_{i}` is multiplied after a restart. Default: 1. 21 | eta_min (float, optional): Minimum learning rate. Default: 0. 22 | last_epoch (int, optional): The index of the last epoch. Default: -1. 23 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 24 | https://arxiv.org/abs/1608.03983 25 | """ 26 | 27 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): 28 | if T_0 <= 0 or not isinstance(T_0, int): 29 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 30 | if T_mult < 1 or not isinstance(T_mult, int): 31 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) 32 | self.T_0 = T_0 33 | self.T_i = T_0 34 | self.T_mult = T_mult 35 | self.eta_min = eta_min 36 | self.T_cur = last_epoch # must exist before the base class calls self.step() 37 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) 38 | 39 | def get_lr(self): 40 | return [ 41 | self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 42 | for base_lr in self.base_lrs 43 | ] 44 | 45 | def step(self, epoch=None): 46 | """Step can be called after every batch update, e.g. if one epoch has 10 iterations 47 | (number_of_train_examples / batch_size), call scheduler.step(0.1), scheduler.step(0.2), etc. 48 | This function can be called in an interleaved way.
49 | Example: 50 | >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) 51 | >>> for epoch in range(20): 52 | >>> scheduler.step() 53 | >>> scheduler.step(26) 54 | >>> scheduler.step() # scheduler.step(27), instead of scheduler.step(20) 55 | """ 56 | if epoch is None: 57 | epoch = self.last_epoch + 1 58 | self.T_cur = self.T_cur + 1 59 | if self.T_cur >= self.T_i: 60 | self.T_cur = self.T_cur - self.T_i 61 | self.T_i = self.T_i * self.T_mult 62 | else: 63 | if epoch >= self.T_0: 64 | if self.T_mult == 1: 65 | self.T_cur = epoch % self.T_0 66 | else: 67 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 68 | self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (self.T_mult - 1) 69 | self.T_i = self.T_0 * self.T_mult**(n) 70 | else: 71 | self.T_i = self.T_0 72 | self.T_cur = epoch 73 | self.last_epoch = math.floor(epoch) 74 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 75 | param_group['lr'] = lr 76 | -------------------------------------------------------------------------------- /resflows/optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class Adam(Optimizer): 7 | """Implements Adam algorithm. 8 | 9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 21 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 22 | (default: False) 23 | 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _On the Convergence of Adam and Beyond: 27 | https://openreview.net/forum?id=ryQu7f-RZ 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): 31 | if not 0.0 <= lr: 32 | raise ValueError("Invalid learning rate: {}".format(lr)) 33 | if not 0.0 <= eps: 34 | raise ValueError("Invalid epsilon value: {}".format(eps)) 35 | if not 0.0 <= betas[0] < 1.0: 36 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 37 | if not 0.0 <= betas[1] < 1.0: 38 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 39 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) 40 | super(Adam, self).__init__(params, defaults) 41 | 42 | def __setstate__(self, state): 43 | super(Adam, self).__setstate__(state) 44 | for group in self.param_groups: 45 | group.setdefault('amsgrad', False) 46 | 47 | def step(self, closure=None): 48 | """Performs a single optimization step. 49 | 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 
53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | grad = p.grad.data 63 | if grad.is_sparse: 64 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 65 | amsgrad = group['amsgrad'] 66 | 67 | state = self.state[p] 68 | 69 | # State initialization 70 | if len(state) == 0: 71 | state['step'] = 0 72 | # Exponential moving average of gradient values 73 | state['exp_avg'] = torch.zeros_like(p.data) 74 | # Exponential moving average of squared gradient values 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | if amsgrad: 77 | # Maintains max of all exp. moving avg. of sq. grad. values 78 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 79 | 80 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 81 | if amsgrad: 82 | max_exp_avg_sq = state['max_exp_avg_sq'] 83 | beta1, beta2 = group['betas'] 84 | 85 | state['step'] += 1 86 | 87 | # Decay the first and second moment running average coefficient 88 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 89 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 90 | if amsgrad: 91 | # Maintains the maximum of all 2nd moment running avg. till now 92 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 93 | # Use the max. for normalizing running avg.
of gradient 94 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 95 | else: 96 | denom = exp_avg_sq.sqrt().add_(group['eps']) 97 | 98 | bias_correction1 = 1 - beta1**state['step'] 99 | bias_correction2 = 1 - beta2**state['step'] 100 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 101 | 102 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 103 | 104 | if group['weight_decay'] != 0: 105 | p.data.add_(p.data, alpha=-step_size * group['weight_decay']) 106 | 107 | return loss 108 | 109 | 110 | class Adamax(Optimizer): 111 | """Implements Adamax algorithm (a variant of Adam based on infinity norm). 112 | 113 | It has been proposed in `Adam: A Method for Stochastic Optimization`__. 114 | 115 | Arguments: 116 | params (iterable): iterable of parameters to optimize or dicts defining 117 | parameter groups 118 | lr (float, optional): learning rate (default: 2e-3) 119 | betas (Tuple[float, float], optional): coefficients used for computing 120 | running averages of gradient and its square (default: (0.9, 0.999)) 121 | eps (float, optional): term added to the denominator to improve 122 | numerical stability (default: 1e-8) 123 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 124 | 125 | __ https://arxiv.org/abs/1412.6980 126 | """ 127 | 128 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 129 | if not 0.0 <= lr: 130 | raise ValueError("Invalid learning rate: {}".format(lr)) 131 | if not 0.0 <= eps: 132 | raise ValueError("Invalid epsilon value: {}".format(eps)) 133 | if not 0.0 <= betas[0] < 1.0: 134 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 135 | if not 0.0 <= betas[1] < 1.0: 136 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 137 | if not 0.0 <= weight_decay: 138 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 139 | 140 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 141 | super(Adamax, self).__init__(params,
defaults) 142 | 143 | def step(self, closure=None): 144 | """Performs a single optimization step. 145 | 146 | Arguments: 147 | closure (callable, optional): A closure that reevaluates the model 148 | and returns the loss. 149 | """ 150 | loss = None 151 | if closure is not None: 152 | loss = closure() 153 | 154 | for group in self.param_groups: 155 | for p in group['params']: 156 | if p.grad is None: 157 | continue 158 | grad = p.grad.data 159 | if grad.is_sparse: 160 | raise RuntimeError('Adamax does not support sparse gradients') 161 | state = self.state[p] 162 | 163 | # State initialization 164 | if len(state) == 0: 165 | state['step'] = 0 166 | state['exp_avg'] = torch.zeros_like(p.data) 167 | state['exp_inf'] = torch.zeros_like(p.data) 168 | 169 | exp_avg, exp_inf = state['exp_avg'], state['exp_inf'] 170 | beta1, beta2 = group['betas'] 171 | eps = group['eps'] 172 | 173 | state['step'] += 1 174 | 175 | # Update biased first moment estimate. 176 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 177 | # Update the exponentially weighted infinity norm. 178 | norm_buf = torch.cat([exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0) 179 | torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long())) 180 | 181 | bias_correction = 1 - beta1**state['step'] 182 | clr = group['lr'] / bias_correction 183 | 184 | p.data.addcdiv_(exp_avg, exp_inf, value=-clr) 185 | 186 | if group['weight_decay'] != 0: 187 | p.data.add_(p.data, alpha=-clr * group['weight_decay']) 188 | 189 | return loss 190 | 191 | 192 | class RMSprop(Optimizer): 193 | """Implements RMSprop algorithm. 194 | 195 | Proposed by G. Hinton in his 196 | `course `_. 197 | 198 | The centered version first appears in `Generating Sequences 199 | With Recurrent Neural Networks `_.
200 | 201 | Arguments: 202 | params (iterable): iterable of parameters to optimize or dicts defining 203 | parameter groups 204 | lr (float, optional): learning rate (default: 1e-2) 205 | momentum (float, optional): momentum factor (default: 0) 206 | alpha (float, optional): smoothing constant (default: 0.99) 207 | eps (float, optional): term added to the denominator to improve 208 | numerical stability (default: 1e-8) 209 | centered (bool, optional): if ``True``, compute the centered RMSProp, in which 210 | the gradient is normalized by an estimate of its variance (default: False) 211 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 212 | 213 | """ 214 | 215 | def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False): 216 | if not 0.0 <= lr: 217 | raise ValueError("Invalid learning rate: {}".format(lr)) 218 | if not 0.0 <= eps: 219 | raise ValueError("Invalid epsilon value: {}".format(eps)) 220 | if not 0.0 <= momentum: 221 | raise ValueError("Invalid momentum value: {}".format(momentum)) 222 | if not 0.0 <= weight_decay: 223 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 224 | if not 0.0 <= alpha: 225 | raise ValueError("Invalid alpha value: {}".format(alpha)) 226 | 227 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) 228 | super(RMSprop, self).__init__(params, defaults) 229 | 230 | def __setstate__(self, state): 231 | super(RMSprop, self).__setstate__(state) 232 | for group in self.param_groups: 233 | group.setdefault('momentum', 0) 234 | group.setdefault('centered', False) 235 | 236 | def step(self, closure=None): 237 | """Performs a single optimization step. 238 | 239 | Arguments: 240 | closure (callable, optional): A closure that reevaluates the model 241 | and returns the loss.
242 | """ 243 | loss = None 244 | if closure is not None: 245 | loss = closure() 246 | 247 | for group in self.param_groups: 248 | for p in group['params']: 249 | if p.grad is None: 250 | continue 251 | grad = p.grad.data 252 | if grad.is_sparse: 253 | raise RuntimeError('RMSprop does not support sparse gradients') 254 | state = self.state[p] 255 | 256 | # State initialization 257 | if len(state) == 0: 258 | state['step'] = 0 259 | state['square_avg'] = torch.zeros_like(p.data) 260 | if group['momentum'] > 0: 261 | state['momentum_buffer'] = torch.zeros_like(p.data) 262 | if group['centered']: 263 | state['grad_avg'] = torch.zeros_like(p.data) 264 | 265 | square_avg = state['square_avg'] 266 | alpha = group['alpha'] 267 | 268 | state['step'] += 1 269 | 270 | square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) 271 | 272 | if group['centered']: 273 | grad_avg = state['grad_avg'] 274 | grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) 275 | avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt().add_(group['eps']) 276 | else: 277 | avg = square_avg.sqrt().add_(group['eps']) 278 | 279 | if group['momentum'] > 0: 280 | buf = state['momentum_buffer'] 281 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 282 | p.data.add_(buf, alpha=-group['lr']) 283 | else: 284 | p.data.addcdiv_(grad, avg, value=-group['lr']) 285 | 286 | if group['weight_decay'] != 0: 287 | p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay']) 288 | 289 | return loss 290 | -------------------------------------------------------------------------------- /resflows/resflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | import resflows.layers as layers 6 | import resflows.layers.base as base_layers 7 | 8 | ACT_FNS = { 9 | 'softplus': lambda b: nn.Softplus(), 10 | 'elu': lambda b: nn.ELU(inplace=b), 11 | 'swish': lambda b: base_layers.Swish(), 12 | 'lcube': lambda b: base_layers.LipschitzCube(), 13 | 'identity':
lambda b: base_layers.Identity(), 14 | 'relu': lambda b: nn.ReLU(inplace=b), 15 | } 16 | 17 | 18 | class ResidualFlow(nn.Module): 19 | 20 | def __init__( 21 | self, 22 | input_size, 23 | n_blocks=[16, 16], 24 | intermediate_dim=64, 25 | factor_out=True, 26 | quadratic=False, 27 | init_layer=None, 28 | actnorm=False, 29 | fc_actnorm=False, 30 | batchnorm=False, 31 | dropout=0, 32 | fc=False, 33 | coeff=0.9, 34 | vnorms='122f', 35 | n_lipschitz_iters=None, 36 | sn_atol=None, 37 | sn_rtol=None, 38 | n_power_series=5, 39 | n_dist='geometric', 40 | n_samples=1, 41 | kernels='3-1-3', 42 | activation_fn='elu', 43 | fc_end=True, 44 | fc_idim=128, 45 | n_exact_terms=0, 46 | preact=False, 47 | neumann_grad=True, 48 | grad_in_forward=False, 49 | first_resblock=False, 50 | learn_p=False, 51 | classification=False, 52 | classification_hdim=64, 53 | n_classes=10, 54 | block_type='resblock', 55 | ): 56 | super(ResidualFlow, self).__init__() 57 | self.n_scale = min(len(n_blocks), self._calc_n_scale(input_size)) 58 | self.n_blocks = n_blocks 59 | self.intermediate_dim = intermediate_dim 60 | self.factor_out = factor_out 61 | self.quadratic = quadratic 62 | self.init_layer = init_layer 63 | self.actnorm = actnorm 64 | self.fc_actnorm = fc_actnorm 65 | self.batchnorm = batchnorm 66 | self.dropout = dropout 67 | self.fc = fc 68 | self.coeff = coeff 69 | self.vnorms = vnorms 70 | self.n_lipschitz_iters = n_lipschitz_iters 71 | self.sn_atol = sn_atol 72 | self.sn_rtol = sn_rtol 73 | self.n_power_series = n_power_series 74 | self.n_dist = n_dist 75 | self.n_samples = n_samples 76 | self.kernels = kernels 77 | self.activation_fn = activation_fn 78 | self.fc_end = fc_end 79 | self.fc_idim = fc_idim 80 | self.n_exact_terms = n_exact_terms 81 | self.preact = preact 82 | self.neumann_grad = neumann_grad 83 | self.grad_in_forward = grad_in_forward 84 | self.first_resblock = first_resblock 85 | self.learn_p = learn_p 86 | self.classification = classification 87 | self.classification_hdim = 
classification_hdim 88 | self.n_classes = n_classes 89 | self.block_type = block_type 90 | 91 | if not self.n_scale > 0: 92 | raise ValueError('Could not compute number of scales for input of size (%d,%d,%d,%d)' % tuple(input_size)) 93 | 94 | self.transforms = self._build_net(input_size) 95 | 96 | self.dims = [o[1:] for o in self.calc_output_size(input_size)] 97 | 98 | if self.classification: 99 | self.build_multiscale_classifier(input_size) 100 | 101 | def _build_net(self, input_size): 102 | _, c, h, w = input_size 103 | transforms = [] 104 | _stacked_blocks = StackediResBlocks if self.block_type == 'resblock' else StackedCouplingBlocks 105 | for i in range(self.n_scale): 106 | transforms.append( 107 | _stacked_blocks( 108 | initial_size=(c, h, w), 109 | idim=self.intermediate_dim, 110 | squeeze=(i < self.n_scale - 1), # don't squeeze last layer 111 | init_layer=self.init_layer if i == 0 else None, 112 | n_blocks=self.n_blocks[i], 113 | quadratic=self.quadratic, 114 | actnorm=self.actnorm, 115 | fc_actnorm=self.fc_actnorm, 116 | batchnorm=self.batchnorm, 117 | dropout=self.dropout, 118 | fc=self.fc, 119 | coeff=self.coeff, 120 | vnorms=self.vnorms, 121 | n_lipschitz_iters=self.n_lipschitz_iters, 122 | sn_atol=self.sn_atol, 123 | sn_rtol=self.sn_rtol, 124 | n_power_series=self.n_power_series, 125 | n_dist=self.n_dist, 126 | n_samples=self.n_samples, 127 | kernels=self.kernels, 128 | activation_fn=self.activation_fn, 129 | fc_end=self.fc_end, 130 | fc_idim=self.fc_idim, 131 | n_exact_terms=self.n_exact_terms, 132 | preact=self.preact, 133 | neumann_grad=self.neumann_grad, 134 | grad_in_forward=self.grad_in_forward, 135 | first_resblock=self.first_resblock and (i == 0), 136 | learn_p=self.learn_p, 137 | ) 138 | ) 139 | c, h, w = c * 2 if self.factor_out else c * 4, h // 2, w // 2 140 | return nn.ModuleList(transforms) 141 | 142 | def _calc_n_scale(self, input_size): 143 | _, _, h, w = input_size 144 | n_scale = 0 145 | while h >= 4 and w >= 4: 146 | n_scale += 1 147 | h =
h // 2 148 | w = w // 2 149 | return n_scale 150 | 151 | def calc_output_size(self, input_size): 152 | n, c, h, w = input_size 153 | if not self.factor_out: 154 | k = self.n_scale - 1 155 | return [[n, c * 4**k, h // 2**k, w // 2**k]] 156 | output_sizes = [] 157 | for i in range(self.n_scale): 158 | if i < self.n_scale - 1: 159 | c *= 2 160 | h //= 2 161 | w //= 2 162 | output_sizes.append((n, c, h, w)) 163 | else: 164 | output_sizes.append((n, c, h, w)) 165 | return tuple(output_sizes) 166 | 167 | def build_multiscale_classifier(self, input_size): 168 | n, c, h, w = input_size 169 | hidden_shapes = [] 170 | for i in range(self.n_scale): 171 | if i < self.n_scale - 1: 172 | c *= 2 if self.factor_out else 4 173 | h //= 2 174 | w //= 2 175 | hidden_shapes.append((n, c, h, w)) 176 | 177 | classification_heads = [] 178 | for i, hshape in enumerate(hidden_shapes): 179 | classification_heads.append( 180 | nn.Sequential( 181 | nn.Conv2d(hshape[1], self.classification_hdim, 3, 1, 1), 182 | layers.ActNorm2d(self.classification_hdim), 183 | nn.ReLU(inplace=True), 184 | nn.AdaptiveAvgPool2d((1, 1)), 185 | ) 186 | ) 187 | self.classification_heads = nn.ModuleList(classification_heads) 188 | self.logit_layer = nn.Linear(self.classification_hdim * len(classification_heads), self.n_classes) 189 | 190 | def forward(self, x, logpx=None, inverse=False, classify=False): 191 | if inverse: 192 | return self.inverse(x, logpx) 193 | out = [] 194 | if classify: class_outs = [] 195 | for idx in range(len(self.transforms)): 196 | if logpx is not None: 197 | x, logpx = self.transforms[idx].forward(x, logpx) 198 | else: 199 | x = self.transforms[idx].forward(x) 200 | if self.factor_out and (idx < len(self.transforms) - 1): 201 | d = x.size(1) // 2 202 | x, f = x[:, :d], x[:, d:] 203 | out.append(f) 204 | 205 | # Handle classification. 
206 | if classify: 207 | if self.factor_out: 208 | class_outs.append(self.classification_heads[idx](f)) 209 | else: 210 | class_outs.append(self.classification_heads[idx](x)) 211 | 212 | out.append(x) 213 | out = torch.cat([o.view(o.size()[0], -1) for o in out], 1) 214 | output = out if logpx is None else (out, logpx) 215 | if classify: 216 | h = torch.cat(class_outs, dim=1).squeeze(-1).squeeze(-1) 217 | logits = self.logit_layer(h) 218 | return output, logits 219 | else: 220 | return output 221 | 222 | def inverse(self, z, logpz=None): 223 | if self.factor_out: 224 | z = z.view(z.shape[0], -1) 225 | zs = [] 226 | i = 0 227 | for dims in self.dims: 228 | s = np.prod(dims) 229 | zs.append(z[:, i:i + s]) 230 | i += s 231 | zs = [_z.view(_z.size()[0], *zsize) for _z, zsize in zip(zs, self.dims)] 232 | 233 | if logpz is None: 234 | z_prev = self.transforms[-1].inverse(zs[-1]) 235 | for idx in range(len(self.transforms) - 2, -1, -1): 236 | z_prev = torch.cat((z_prev, zs[idx]), dim=1) 237 | z_prev = self.transforms[idx].inverse(z_prev) 238 | return z_prev 239 | else: 240 | z_prev, logpz = self.transforms[-1].inverse(zs[-1], logpz) 241 | for idx in range(len(self.transforms) - 2, -1, -1): 242 | z_prev = torch.cat((z_prev, zs[idx]), dim=1) 243 | z_prev, logpz = self.transforms[idx].inverse(z_prev, logpz) 244 | return z_prev, logpz 245 | else: 246 | z = z.view(z.shape[0], *self.dims[-1]) 247 | for idx in range(len(self.transforms) - 1, -1, -1): 248 | if logpz is None: 249 | z = self.transforms[idx].inverse(z) 250 | else: 251 | z, logpz = self.transforms[idx].inverse(z, logpz) 252 | return z if logpz is None else (z, logpz) 253 | 254 | 255 | class StackediResBlocks(layers.SequentialFlow): 256 | 257 | def __init__( 258 | self, 259 | initial_size, 260 | idim, 261 | squeeze=True, 262 | init_layer=None, 263 | n_blocks=1, 264 | quadratic=False, 265 | actnorm=False, 266 | fc_actnorm=False, 267 | batchnorm=False, 268 | dropout=0, 269 | fc=False, 270 | coeff=0.9, 271 | 
vnorms='122f', 272 | n_lipschitz_iters=None, 273 | sn_atol=None, 274 | sn_rtol=None, 275 | n_power_series=5, 276 | n_dist='geometric', 277 | n_samples=1, 278 | kernels='3-1-3', 279 | activation_fn='elu', 280 | fc_end=True, 281 | fc_nblocks=4, 282 | fc_idim=128, 283 | n_exact_terms=0, 284 | preact=False, 285 | neumann_grad=True, 286 | grad_in_forward=False, 287 | first_resblock=False, 288 | learn_p=False, 289 | ): 290 | 291 | chain = [] 292 | 293 | # Parse vnorms 294 | ps = [] 295 | for p in vnorms: 296 | if p == 'f': 297 | ps.append(float('inf')) 298 | else: 299 | ps.append(float(p)) 300 | domains, codomains = ps[:-1], ps[1:] 301 | assert len(domains) == len(kernels.split('-')) 302 | 303 | def _actnorm(size, fc): 304 | if fc: 305 | return FCWrapper(layers.ActNorm1d(size[0] * size[1] * size[2])) 306 | else: 307 | return layers.ActNorm2d(size[0]) 308 | 309 | def _quadratic_layer(initial_size, fc): 310 | if fc: 311 | c, h, w = initial_size 312 | dim = c * h * w 313 | return FCWrapper(layers.InvertibleLinear(dim)) 314 | else: 315 | return layers.InvertibleConv2d(initial_size[0]) 316 | 317 | def _lipschitz_layer(fc): 318 | return base_layers.get_linear if fc else base_layers.get_conv2d 319 | 320 | def _resblock(initial_size, fc, idim=idim, first_resblock=False): 321 | if fc: 322 | return layers.iResBlock( 323 | FCNet( 324 | input_shape=initial_size, 325 | idim=idim, 326 | lipschitz_layer=_lipschitz_layer(True), 327 | nhidden=len(kernels.split('-')) - 1, 328 | coeff=coeff, 329 | domains=domains, 330 | codomains=codomains, 331 | n_iterations=n_lipschitz_iters, 332 | activation_fn=activation_fn, 333 | preact=preact, 334 | dropout=dropout, 335 | sn_atol=sn_atol, 336 | sn_rtol=sn_rtol, 337 | learn_p=learn_p, 338 | ), 339 | n_power_series=n_power_series, 340 | n_dist=n_dist, 341 | n_samples=n_samples, 342 | n_exact_terms=n_exact_terms, 343 | neumann_grad=neumann_grad, 344 | grad_in_forward=grad_in_forward, 345 | ) 346 | else: 347 | ks = list(map(int, kernels.split('-'))) 348 
| if learn_p: 349 | _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))] 350 | _codomains = _domains[1:] + [_domains[0]] 351 | else: 352 | _domains = domains 353 | _codomains = codomains 354 | nnet = [] 355 | if not first_resblock and preact: 356 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0])) 357 | nnet.append(ACT_FNS[activation_fn](False)) 358 | nnet.append( 359 | _lipschitz_layer(fc)( 360 | initial_size[0], idim, ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters, 361 | domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol 362 | ) 363 | ) 364 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim)) 365 | nnet.append(ACT_FNS[activation_fn](True)) 366 | for i, k in enumerate(ks[1:-1]): 367 | nnet.append( 368 | _lipschitz_layer(fc)( 369 | idim, idim, k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters, 370 | domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol 371 | ) 372 | ) 373 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim)) 374 | nnet.append(ACT_FNS[activation_fn](True)) 375 | if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True)) 376 | nnet.append( 377 | _lipschitz_layer(fc)( 378 | idim, initial_size[0], ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters, 379 | domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol 380 | ) 381 | ) 382 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0])) 383 | return layers.iResBlock( 384 | nn.Sequential(*nnet), 385 | n_power_series=n_power_series, 386 | n_dist=n_dist, 387 | n_samples=n_samples, 388 | n_exact_terms=n_exact_terms, 389 | neumann_grad=neumann_grad, 390 | grad_in_forward=grad_in_forward, 391 | ) 392 | 393 | if init_layer is not None: chain.append(init_layer) 394 | if first_resblock and actnorm: chain.append(_actnorm(initial_size, fc)) 395 | if first_resblock and fc_actnorm: chain.append(_actnorm(initial_size, True)) 396 | 397 | if squeeze: 398 
| c, h, w = initial_size 399 | for i in range(n_blocks): 400 | if quadratic: chain.append(_quadratic_layer(initial_size, fc)) 401 | chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0))) 402 | if actnorm: chain.append(_actnorm(initial_size, fc)) 403 | if fc_actnorm: chain.append(_actnorm(initial_size, True)) 404 | chain.append(layers.SqueezeLayer(2)) 405 | else: 406 | for _ in range(n_blocks): 407 | if quadratic: chain.append(_quadratic_layer(initial_size, fc)) 408 | chain.append(_resblock(initial_size, fc)) 409 | if actnorm: chain.append(_actnorm(initial_size, fc)) 410 | if fc_actnorm: chain.append(_actnorm(initial_size, True)) 411 | # Use four fully connected layers at the end. 412 | if fc_end: 413 | for _ in range(fc_nblocks): 414 | chain.append(_resblock(initial_size, True, fc_idim)) 415 | if actnorm or fc_actnorm: chain.append(_actnorm(initial_size, True)) 416 | 417 | super(StackediResBlocks, self).__init__(chain) 418 | 419 | 420 | class FCNet(nn.Module): 421 | 422 | def __init__( 423 | self, input_shape, idim, lipschitz_layer, nhidden, coeff, domains, codomains, n_iterations, activation_fn, 424 | preact, dropout, sn_atol, sn_rtol, learn_p, div_in=1 425 | ): 426 | super(FCNet, self).__init__() 427 | self.input_shape = input_shape 428 | c, h, w = self.input_shape 429 | dim = c * h * w 430 | nnet = [] 431 | last_dim = dim // div_in 432 | if preact: nnet.append(ACT_FNS[activation_fn](False)) 433 | if learn_p: 434 | domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(domains))] 435 | codomains = domains[1:] + [domains[0]] 436 | for i in range(nhidden): 437 | nnet.append( 438 | lipschitz_layer(last_dim, idim) if lipschitz_layer == nn.Linear else lipschitz_layer( 439 | last_dim, idim, coeff=coeff, n_iterations=n_iterations, domain=domains[i], codomain=codomains[i], 440 | atol=sn_atol, rtol=sn_rtol 441 | ) 442 | ) 443 | nnet.append(ACT_FNS[activation_fn](True)) 444 | last_dim = idim 445 | if dropout: 
nnet.append(nn.Dropout(dropout, inplace=True)) 446 | nnet.append( 447 | lipschitz_layer(last_dim, dim) if lipschitz_layer == nn.Linear else lipschitz_layer( 448 | last_dim, dim, coeff=coeff, n_iterations=n_iterations, domain=domains[-1], codomain=codomains[-1], 449 | atol=sn_atol, rtol=sn_rtol 450 | ) 451 | ) 452 | self.nnet = nn.Sequential(*nnet) 453 | 454 | def forward(self, x): 455 | x = x.view(x.shape[0], -1) 456 | y = self.nnet(x) 457 | return y.view(y.shape[0], *self.input_shape) 458 | 459 | 460 | class FCWrapper(nn.Module): 461 | 462 | def __init__(self, fc_module): 463 | super(FCWrapper, self).__init__() 464 | self.fc_module = fc_module 465 | 466 | def forward(self, x, logpx=None): 467 | shape = x.shape 468 | x = x.view(x.shape[0], -1) 469 | if logpx is None: 470 | y = self.fc_module(x) 471 | return y.view(*shape) 472 | else: 473 | y, logpy = self.fc_module(x, logpx) 474 | return y.view(*shape), logpy 475 | 476 | def inverse(self, y, logpy=None): 477 | shape = y.shape 478 | y = y.view(y.shape[0], -1) 479 | if logpy is None: 480 | x = self.fc_module.inverse(y) 481 | return x.view(*shape) 482 | else: 483 | x, logpx = self.fc_module.inverse(y, logpy) 484 | return x.view(*shape), logpx 485 | 486 | 487 | class StackedCouplingBlocks(layers.SequentialFlow): 488 | 489 | def __init__( 490 | self, 491 | initial_size, 492 | idim, 493 | squeeze=True, 494 | init_layer=None, 495 | n_blocks=1, 496 | quadratic=False, 497 | actnorm=False, 498 | fc_actnorm=False, 499 | batchnorm=False, 500 | dropout=0, 501 | fc=False, 502 | coeff=0.9, 503 | vnorms='122f', 504 | n_lipschitz_iters=None, 505 | sn_atol=None, 506 | sn_rtol=None, 507 | n_power_series=5, 508 | n_dist='geometric', 509 | n_samples=1, 510 | kernels='3-1-3', 511 | activation_fn='elu', 512 | fc_end=True, 513 | fc_nblocks=4, 514 | fc_idim=128, 515 | n_exact_terms=0, 516 | preact=False, 517 | neumann_grad=True, 518 | grad_in_forward=False, 519 | first_resblock=False, 520 | learn_p=False, 521 | ): 522 | 523 | # yapf: 
disable 524 | class nonloc_scope: pass 525 | nonloc_scope.swap = True 526 | # yapf: enable 527 | 528 | chain = [] 529 | 530 | def _actnorm(size, fc): 531 | if fc: 532 | return FCWrapper(layers.ActNorm1d(size[0] * size[1] * size[2])) 533 | else: 534 | return layers.ActNorm2d(size[0]) 535 | 536 | def _quadratic_layer(initial_size, fc): 537 | if fc: 538 | c, h, w = initial_size 539 | dim = c * h * w 540 | return FCWrapper(layers.InvertibleLinear(dim)) 541 | else: 542 | return layers.InvertibleConv2d(initial_size[0]) 543 | 544 | def _weight_layer(fc): 545 | return nn.Linear if fc else nn.Conv2d 546 | 547 | def _resblock(initial_size, fc, idim=idim, first_resblock=False): 548 | if fc: 549 | nonloc_scope.swap = not nonloc_scope.swap 550 | return layers.CouplingBlock( 551 | initial_size[0], 552 | FCNet( 553 | input_shape=initial_size, 554 | idim=idim, 555 | lipschitz_layer=_weight_layer(True), 556 | nhidden=len(kernels.split('-')) - 1, 557 | activation_fn=activation_fn, 558 | preact=preact, 559 | dropout=dropout, 560 | coeff=None, 561 | domains=None, 562 | codomains=None, 563 | n_iterations=None, 564 | sn_atol=None, 565 | sn_rtol=None, 566 | learn_p=None, 567 | div_in=2, 568 | ), 569 | swap=nonloc_scope.swap, 570 | ) 571 | else: 572 | ks = list(map(int, kernels.split('-'))) 573 | 574 | if init_layer is None: 575 | _block = layers.ChannelCouplingBlock 576 | _mask_type = 'channel' 577 | div_in = 2 578 | mult_out = 1 579 | else: 580 | _block = layers.MaskedCouplingBlock 581 | _mask_type = 'checkerboard' 582 | div_in = 1 583 | mult_out = 2 584 | 585 | nonloc_scope.swap = not nonloc_scope.swap 586 | _mask_type += '1' if nonloc_scope.swap else '0' 587 | 588 | nnet = [] 589 | if not first_resblock and preact: 590 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0])) 591 | nnet.append(ACT_FNS[activation_fn](False)) 592 | nnet.append(_weight_layer(fc)(initial_size[0] // div_in, idim, ks[0], 1, ks[0] // 2)) 593 | if batchnorm: 
nnet.append(layers.MovingBatchNorm2d(idim)) 594 | nnet.append(ACT_FNS[activation_fn](True)) 595 | for i, k in enumerate(ks[1:-1]): 596 | nnet.append(_weight_layer(fc)(idim, idim, k, 1, k // 2)) 597 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim)) 598 | nnet.append(ACT_FNS[activation_fn](True)) 599 | if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True)) 600 | nnet.append(_weight_layer(fc)(idim, initial_size[0] * mult_out, ks[-1], 1, ks[-1] // 2)) 601 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0])) 602 | 603 | return _block(initial_size[0], nn.Sequential(*nnet), mask_type=_mask_type) 604 | 605 | if init_layer is not None: chain.append(init_layer) 606 | if first_resblock and actnorm: chain.append(_actnorm(initial_size, fc)) 607 | if first_resblock and fc_actnorm: chain.append(_actnorm(initial_size, True)) 608 | 609 | if squeeze: 610 | c, h, w = initial_size 611 | for i in range(n_blocks): 612 | if quadratic: chain.append(_quadratic_layer(initial_size, fc)) 613 | chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0))) 614 | if actnorm: chain.append(_actnorm(initial_size, fc)) 615 | if fc_actnorm: chain.append(_actnorm(initial_size, True)) 616 | chain.append(layers.SqueezeLayer(2)) 617 | else: 618 | for _ in range(n_blocks): 619 | if quadratic: chain.append(_quadratic_layer(initial_size, fc)) 620 | chain.append(_resblock(initial_size, fc)) 621 | if actnorm: chain.append(_actnorm(initial_size, fc)) 622 | if fc_actnorm: chain.append(_actnorm(initial_size, True)) 623 | # Use four fully connected layers at the end. 
624 | if fc_end: 625 | for _ in range(fc_nblocks): 626 | chain.append(_resblock(initial_size, True, fc_idim)) 627 | if actnorm or fc_actnorm: chain.append(_actnorm(initial_size, True)) 628 | 629 | super(StackedCouplingBlocks, self).__init__(chain) 630 | -------------------------------------------------------------------------------- /resflows/toy_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn 3 | import sklearn.datasets 4 | from sklearn.utils import shuffle as util_shuffle 5 | 6 | 7 | # Dataset iterator 8 | def inf_train_gen(data, batch_size=200): 9 | 10 | if data == "swissroll": 11 | data = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0)[0] 12 | data = data.astype("float32")[:, [0, 2]] 13 | data /= 5 14 | return data 15 | 16 | elif data == "circles": 17 | data = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08)[0] 18 | data = data.astype("float32") 19 | data *= 3 20 | return data 21 | 22 | elif data == "rings": 23 | n_samples4 = n_samples3 = n_samples2 = batch_size // 4 24 | n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2 25 | 26 | # so as not to have the first point = last point, we set endpoint=False 27 | linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False) 28 | linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False) 29 | linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False) 30 | linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False) 31 | 32 | circ4_x = np.cos(linspace4) 33 | circ4_y = np.sin(linspace4) 34 | circ3_x = np.cos(linspace3) * 0.75 35 | circ3_y = np.sin(linspace3) * 0.75 36 | circ2_x = np.cos(linspace2) * 0.5 37 | circ2_y = np.sin(linspace2) * 0.5 38 | circ1_x = np.cos(linspace1) * 0.25 39 | circ1_y = np.sin(linspace1) * 0.25 40 | 41 | X = np.vstack([ 42 | np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]), 43 | np.hstack([circ4_y, circ3_y, circ2_y, circ1_y]) 44 | ]).T
* 3.0 45 | X = util_shuffle(X) 46 | 47 | # Add noise 48 | X = X + np.random.normal(scale=0.08, size=X.shape) 49 | 50 | return X.astype("float32") 51 | 52 | elif data == "moons": 53 | data = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0] 54 | data = data.astype("float32") 55 | data = data * 2 + np.array([-1, -0.2]) 56 | return data 57 | 58 | elif data == "8gaussians": 59 | scale = 4. 60 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 61 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 62 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))] 63 | centers = [(scale * x, scale * y) for x, y in centers] 64 | 65 | dataset = [] 66 | for i in range(batch_size): 67 | point = np.random.randn(2) * 0.5 68 | idx = np.random.randint(8) 69 | center = centers[idx] 70 | point[0] += center[0] 71 | point[1] += center[1] 72 | dataset.append(point) 73 | dataset = np.array(dataset, dtype="float32") 74 | dataset /= 1.414 75 | return dataset 76 | 77 | elif data == "pinwheel": 78 | radial_std = 0.3 79 | tangential_std = 0.1 80 | num_classes = 5 81 | num_per_class = batch_size // 5 82 | rate = 0.25 83 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 84 | 85 | features = np.random.randn(num_classes*num_per_class, 2) \ 86 | * np.array([radial_std, tangential_std]) 87 | features[:, 0] += 1. 
88 | labels = np.repeat(np.arange(num_classes), num_per_class) 89 | 90 | angles = rads[labels] + rate * np.exp(features[:, 0]) 91 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 92 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 93 | 94 | return 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations)) 95 | 96 | elif data == "2spirals": 97 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 98 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 99 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 100 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 101 | x += np.random.randn(*x.shape) * 0.1 102 | return x 103 | 104 | elif data == "checkerboard": 105 | x1 = np.random.rand(batch_size) * 4 - 2 106 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 107 | x2 = x2_ + (np.floor(x1) % 2) 108 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 109 | 110 | elif data == "line": 111 | x = np.random.rand(batch_size) * 5 - 2.5 112 | y = x 113 | return np.stack((x, y), 1) 114 | elif data == "cos": 115 | x = np.random.rand(batch_size) * 5 - 2.5 116 | y = np.sin(x) * 2.5 117 | return np.stack((x, y), 1) 118 | else: 119 | return inf_train_gen("8gaussians", batch_size) 120 | -------------------------------------------------------------------------------- /resflows/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from numbers import Number 4 | import logging 5 | import torch 6 | 7 | 8 | def makedirs(dirname): 9 | if not os.path.exists(dirname): 10 | os.makedirs(dirname) 11 | 12 | 13 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 14 | logger = logging.getLogger() 15 | if debug: 16 | level = logging.DEBUG 17 | else: 18 | level = logging.INFO 19 | logger.setLevel(level) 20 | if saving: 21 | 
info_file_handler = logging.FileHandler(logpath, mode="a") 22 | info_file_handler.setLevel(level) 23 | logger.addHandler(info_file_handler) 24 | if displaying: 25 | console_handler = logging.StreamHandler() 26 | console_handler.setLevel(level) 27 | logger.addHandler(console_handler) 28 | logger.info(filepath) 29 | with open(filepath, "r") as f: 30 | logger.info(f.read()) 31 | 32 | for f in package_files: 33 | logger.info(f) 34 | with open(f, "r") as package_f: 35 | logger.info(package_f.read()) 36 | 37 | return logger 38 | 39 | 40 | class AverageMeter(object): 41 | """Computes and stores the average and current value""" 42 | 43 | def __init__(self): 44 | self.reset() 45 | 46 | def reset(self): 47 | self.val = 0 48 | self.avg = 0 49 | self.sum = 0 50 | self.count = 0 51 | 52 | def update(self, val, n=1): 53 | self.val = val 54 | self.sum += val * n 55 | self.count += n 56 | self.avg = self.sum / self.count 57 | 58 | 59 | class RunningAverageMeter(object): 60 | """Computes and stores the average and current value""" 61 | 62 | def __init__(self, momentum=0.99): 63 | self.momentum = momentum 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = None 68 | self.avg = 0 69 | 70 | def update(self, val): 71 | if self.val is None: 72 | self.avg = val 73 | else: 74 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 75 | self.val = val 76 | 77 | 78 | def inf_generator(iterable): 79 | """Allows training with DataLoaders in a single infinite loop: 80 | for i, (x, y) in enumerate(inf_generator(train_loader)): 81 | """ 82 | iterator = iterable.__iter__() 83 | while True: 84 | try: 85 | yield iterator.__next__() 86 | except StopIteration: 87 | iterator = iterable.__iter__() 88 | 89 | 90 | def save_checkpoint(state, save, epoch, last_checkpoints=None, num_checkpoints=None): 91 | if not os.path.exists(save): 92 | os.makedirs(save) 93 | filename = os.path.join(save, 'checkpt-%04d.pth' % epoch) 94 | torch.save(state, filename) 95 | 96 | if last_checkpoints is not 
None and num_checkpoints is not None: 97 | last_checkpoints.append(epoch) 98 | if len(last_checkpoints) > num_checkpoints: 99 | rm_epoch = last_checkpoints.pop(0) 100 | os.remove(os.path.join(save, 'checkpt-%04d.pth' % rm_epoch)) 101 | 102 | 103 | def isnan(tensor): 104 | return (tensor != tensor) 105 | 106 | 107 | def logsumexp(value, dim=None, keepdim=False): 108 | """Numerically stable implementation of the operation 109 | value.exp().sum(dim, keepdim).log() 110 | """ 111 | if dim is not None: 112 | m, _ = torch.max(value, dim=dim, keepdim=True) 113 | value0 = value - m 114 | if keepdim is False: 115 | m = m.squeeze(dim) 116 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 117 | else: 118 | m = torch.max(value) 119 | sum_exp = torch.sum(torch.exp(value - m)) 120 | if isinstance(sum_exp, Number): 121 | return m + math.log(sum_exp) 122 | else: 123 | return m + torch.log(sum_exp) 124 | 125 | 126 | class ExponentialMovingAverage(object): 127 | 128 | def __init__(self, module, decay=0.999): 129 | """Initializes the shadow parameters when .apply() is called the first time.
130 | This is to take into account data-dependent initialization that occurs in the first iteration.""" 131 | self.module = module 132 | self.decay = decay 133 | self.shadow_params = {} 134 | self.nparams = sum(p.numel() for p in module.parameters()) 135 | 136 | def init(self): 137 | for name, param in self.module.named_parameters(): 138 | self.shadow_params[name] = param.data.clone() 139 | 140 | def apply(self): 141 | if len(self.shadow_params) == 0: 142 | self.init() 143 | else: 144 | with torch.no_grad(): 145 | for name, param in self.module.named_parameters(): 146 | self.shadow_params[name] -= (1 - self.decay) * (self.shadow_params[name] - param.data) 147 | 148 | def set(self, other_ema): 149 | self.init() 150 | with torch.no_grad(): 151 | for name, param in other_ema.shadow_params.items(): 152 | self.shadow_params[name].copy_(param) 153 | 154 | def replace_with_ema(self): 155 | for name, param in self.module.named_parameters(): 156 | param.data.copy_(self.shadow_params[name]) 157 | 158 | def swap(self): 159 | for name, param in self.module.named_parameters(): 160 | tmp = self.shadow_params[name].clone() 161 | self.shadow_params[name].copy_(param.data) 162 | param.data.copy_(tmp) 163 | 164 | def __repr__(self): 165 | return ( 166 | '{}(decay={}, module={}, nparams={})'.format( 167 | self.__class__.__name__, self.decay, self.module.__class__.__name__, self.nparams 168 | ) 169 | ) 170 | -------------------------------------------------------------------------------- /resflows/visualize_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use("Agg") 4 | import matplotlib.pyplot as plt 5 | import torch 6 | 7 | LOW = -4 8 | HIGH = 4 9 | 10 | 11 | def plt_potential_func(potential, ax, npts=100, title="$p(x)$"): 12 | """ 13 | Args: 14 | potential: computes U(z_k) given z_k 15 | """ 16 | xside = np.linspace(LOW, HIGH, npts) 17 | yside = np.linspace(LOW, HIGH, npts) 18 | xx, 
yy = np.meshgrid(xside, yside) 19 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 20 | 21 | z = torch.Tensor(z) 22 | u = potential(z).cpu().numpy() 23 | p = np.exp(-u).reshape(npts, npts) 24 | 25 | ax.pcolormesh(xx, yy, p) 26 | ax.invert_yaxis() 27 | ax.get_xaxis().set_ticks([]) 28 | ax.get_yaxis().set_ticks([]) 29 | ax.set_title(title) 30 | 31 | 32 | def plt_flow(prior_logdensity, transform, ax, npts=100, title="$q(x)$", device="cpu"): 33 | """ 34 | Args: 35 | transform: computes z_k and log(q_k) given z_0 36 | """ 37 | side = np.linspace(LOW, HIGH, npts) 38 | xx, yy = np.meshgrid(side, side) 39 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 40 | 41 | z = torch.tensor(z, requires_grad=True).type(torch.float32).to(device) 42 | logqz = prior_logdensity(z) 43 | logqz = torch.sum(logqz, dim=1)[:, None] 44 | z, logqz = transform(z, logqz) 45 | logqz = torch.sum(logqz, dim=1)[:, None] 46 | 47 | xx = z[:, 0].detach().cpu().numpy().reshape(npts, npts) 48 | yy = z[:, 1].detach().cpu().numpy().reshape(npts, npts) 49 | qz = np.exp(logqz.detach().cpu().numpy()).reshape(npts, npts) 50 | 51 | ax.pcolormesh(xx, yy, qz) 52 | ax.set_xlim(LOW, HIGH) 53 | ax.set_ylim(LOW, HIGH) 54 | cmap = plt.get_cmap(None) 55 | ax.set_facecolor(cmap(0.)) 56 | ax.invert_yaxis() 57 | ax.get_xaxis().set_ticks([]) 58 | ax.get_yaxis().set_ticks([]) 59 | ax.set_title(title) 60 | 61 | 62 | def plt_flow_density(prior_logdensity, inverse_transform, ax, npts=100, memory=100, title="$q(x)$", device="cpu"): 63 | side = np.linspace(LOW, HIGH, npts) 64 | xx, yy = np.meshgrid(side, side) 65 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]) 66 | 67 | x = torch.from_numpy(x).type(torch.float32).to(device) 68 | zeros = torch.zeros(x.shape[0], 1).to(x) 69 | 70 | z, delta_logp = [], [] 71 | inds = torch.arange(0, x.shape[0]).to(torch.int64) 72 | for ii in torch.split(inds, int(memory**2)): 73 | z_, delta_logp_ = inverse_transform(x[ii], zeros[ii]) 74 | z.append(z_) 75 | delta_logp.append(delta_logp_) 76 | z =
torch.cat(z, 0) 77 | delta_logp = torch.cat(delta_logp, 0) 78 | 79 | logpz = prior_logdensity(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) 80 | logpx = logpz - delta_logp 81 | 82 | px = np.exp(logpx.detach().cpu().numpy()).reshape(npts, npts) 83 | 84 | ax.imshow(px, cmap='inferno') 85 | ax.get_xaxis().set_ticks([]) 86 | ax.get_yaxis().set_ticks([]) 87 | ax.set_title(title) 88 | 89 | 90 | def plt_flow_samples(prior_sample, transform, ax, npts=100, memory=100, title=r"$x \sim q(x)$", device="cpu"): 91 | z = prior_sample(npts * npts, 2).type(torch.float32).to(device) 92 | zk = [] 93 | inds = torch.arange(0, z.shape[0]).to(torch.int64) 94 | for ii in torch.split(inds, int(memory**2)): 95 | zk.append(transform(z[ii])) 96 | zk = torch.cat(zk, 0).detach().cpu().numpy() 97 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts, cmap='inferno') 98 | ax.invert_yaxis() 99 | ax.get_xaxis().set_ticks([]) 100 | ax.get_yaxis().set_ticks([]) 101 | ax.set_title(title) 102 | 103 | 104 | def plt_samples(samples, ax, npts=100, title=r"$x \sim p(x)$"): 105 | ax.hist2d(samples[:, 0], samples[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts, cmap='inferno') 106 | ax.invert_yaxis() 107 | ax.get_xaxis().set_ticks([]) 108 | ax.get_yaxis().set_ticks([]) 109 | ax.set_title(title) 110 | 111 | 112 | def visualize_transform( 113 | potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100, 114 | memory=100, device="cpu" 115 | ): 116 | """Produces visualization for the model density and samples from the model.""" 117 | plt.clf() 118 | ax = plt.subplot(1, 3, 1, aspect="equal") 119 | if samples: 120 | plt_samples(potential_or_samples, ax, npts=npts) 121 | else: 122 | plt_potential_func(potential_or_samples, ax, npts=npts) 123 | 124 | ax = plt.subplot(1, 3, 2, aspect="equal") 125 | if inverse_transform is None: 126 | plt_flow(prior_density, transform, ax, npts=npts, device=device) 127 | else: 128 | plt_flow_density(prior_density,
inverse_transform, ax, npts=npts, memory=memory, device=device) 129 | 130 | ax = plt.subplot(1, 3, 3, aspect="equal") 131 | if transform is not None: 132 | plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device) 133 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name="resflows", 5 | version="0.0.1", 6 | py_modules=["resflows"], 7 | ) -------------------------------------------------------------------------------- /train_img.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import math 4 | import os 5 | import os.path 6 | import numpy as np 7 | from tqdm import tqdm 8 | import gc 9 | 10 | import torch 11 | import torchvision.transforms as transforms 12 | from torchvision.utils import save_image 13 | import torchvision.datasets as vdsets 14 | 15 | from resflows.resflow import ACT_FNS, ResidualFlow 16 | import resflows.datasets as datasets 17 | import resflows.optimizers as optim 18 | import resflows.utils as utils 19 | import resflows.layers as layers 20 | import resflows.layers.base as base_layers 21 | from resflows.lr_scheduler import CosineAnnealingWarmRestarts 22 | 23 | # Arguments 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | '--data', type=str, default='cifar10', choices=[ 27 | 'mnist', 28 | 'cifar10', 29 | 'svhn', 30 | 'celebahq', 31 | 'celeba_5bit', 32 | 'imagenet32', 33 | 'imagenet64', 34 | ] 35 | ) 36 | parser.add_argument('--dataroot', type=str, default='data') 37 | parser.add_argument('--imagesize', type=int, default=32) 38 | parser.add_argument('--nbits', type=int, default=8) # Only used for celebahq. 
39 | 40 | parser.add_argument('--block', type=str, choices=['resblock', 'coupling'], default='resblock') 41 | 42 | parser.add_argument('--coeff', type=float, default=0.98) 43 | parser.add_argument('--vnorms', type=str, default='2222') 44 | parser.add_argument('--n-lipschitz-iters', type=int, default=None) 45 | parser.add_argument('--sn-tol', type=float, default=1e-3) 46 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False) 47 | 48 | parser.add_argument('--n-power-series', type=int, default=None) 49 | parser.add_argument('--factor-out', type=eval, choices=[True, False], default=False) 50 | parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='poisson') 51 | parser.add_argument('--n-samples', type=int, default=1) 52 | parser.add_argument('--n-exact-terms', type=int, default=2) 53 | parser.add_argument('--var-reduc-lr', type=float, default=0) 54 | parser.add_argument('--neumann-grad', type=eval, choices=[True, False], default=True) 55 | parser.add_argument('--mem-eff', type=eval, choices=[True, False], default=True) 56 | 57 | parser.add_argument('--act', type=str, choices=ACT_FNS.keys(), default='swish') 58 | parser.add_argument('--idim', type=int, default=512) 59 | parser.add_argument('--nblocks', type=str, default='16-16-16') 60 | parser.add_argument('--squeeze-first', type=eval, default=False, choices=[True, False]) 61 | parser.add_argument('--actnorm', type=eval, default=True, choices=[True, False]) 62 | parser.add_argument('--fc-actnorm', type=eval, default=False, choices=[True, False]) 63 | parser.add_argument('--batchnorm', type=eval, default=False, choices=[True, False]) 64 | parser.add_argument('--dropout', type=float, default=0.) 
65 | parser.add_argument('--fc', type=eval, default=False, choices=[True, False]) 66 | parser.add_argument('--kernels', type=str, default='3-1-3') 67 | parser.add_argument('--logit-transform', type=eval, choices=[True, False], default=True) 68 | parser.add_argument('--add-noise', type=eval, choices=[True, False], default=True) 69 | parser.add_argument('--quadratic', type=eval, choices=[True, False], default=False) 70 | parser.add_argument('--fc-end', type=eval, choices=[True, False], default=True) 71 | parser.add_argument('--fc-idim', type=int, default=128) 72 | parser.add_argument('--preact', type=eval, choices=[True, False], default=True) 73 | parser.add_argument('--padding', type=int, default=0) 74 | parser.add_argument('--first-resblock', type=eval, choices=[True, False], default=True) 75 | parser.add_argument('--cdim', type=int, default=256) 76 | 77 | parser.add_argument('--optimizer', type=str, choices=['adam', 'adamax', 'rmsprop', 'sgd'], default='adam') 78 | parser.add_argument('--scheduler', type=eval, choices=[True, False], default=False) 79 | parser.add_argument('--nepochs', help='Number of epochs for training', type=int, default=1000) 80 | parser.add_argument('--batchsize', help='Minibatch size', type=int, default=64) 81 | parser.add_argument('--lr', help='Learning rate', type=float, default=1e-3) 82 | parser.add_argument('--wd', help='Weight decay', type=float, default=0) 83 | parser.add_argument('--warmup-iters', type=int, default=1000) 84 | parser.add_argument('--annealing-iters', type=int, default=0) 85 | parser.add_argument('--save', help='directory to save results', type=str, default='experiment1') 86 | parser.add_argument('--val-batchsize', help='minibatch size', type=int, default=200) 87 | parser.add_argument('--seed', type=int, default=None) 88 | parser.add_argument('--ema-val', type=eval, choices=[True, False], default=True) 89 | parser.add_argument('--update-freq', type=int, default=1) 90 | 91 | parser.add_argument('--task', type=str, 
choices=['density', 'classification', 'hybrid'], default='density') 92 | parser.add_argument('--scale-dim', type=eval, choices=[True, False], default=False) 93 | parser.add_argument('--rcrop-pad-mode', type=str, choices=['constant', 'reflect'], default='reflect') 94 | parser.add_argument('--padding-dist', type=str, choices=['uniform', 'gaussian'], default='uniform') 95 | 96 | parser.add_argument('--resume', type=str, default=None) 97 | parser.add_argument('--begin-epoch', type=int, default=0) 98 | 99 | parser.add_argument('--nworkers', type=int, default=4) 100 | parser.add_argument('--print-freq', help='Print progress every so iterations', type=int, default=20) 101 | parser.add_argument('--vis-freq', help='Visualize progress every so iterations', type=int, default=500) 102 | args = parser.parse_args() 103 | 104 | # Random seed 105 | if args.seed is None: 106 | args.seed = np.random.randint(100000) 107 | 108 | # logger 109 | utils.makedirs(args.save) 110 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 111 | logger.info(args) 112 | 113 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 114 | 115 | if device.type == 'cuda': 116 | logger.info('Found {} CUDA devices.'.format(torch.cuda.device_count())) 117 | for i in range(torch.cuda.device_count()): 118 | props = torch.cuda.get_device_properties(i) 119 | logger.info('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3))) 120 | else: 121 | logger.info('WARNING: Using device {}'.format(device)) 122 | 123 | np.random.seed(args.seed) 124 | torch.manual_seed(args.seed) 125 | if device.type == 'cuda': 126 | torch.cuda.manual_seed(args.seed) 127 | 128 | 129 | def geometric_logprob(ns, p): 130 | return torch.log(1 - p + 1e-10) * (ns - 1) + torch.log(p + 1e-10) 131 | 132 | 133 | def standard_normal_sample(size): 134 | return torch.randn(size) 135 | 136 | 137 | def standard_normal_logprob(z): 138 | logZ = -0.5 * math.log(2 * 
math.pi) 139 | return logZ - z.pow(2) / 2 140 | 141 | 142 | def normal_logprob(z, mean, log_std): 143 | mean = mean + torch.tensor(0.) 144 | log_std = log_std + torch.tensor(0.) 145 | c = torch.tensor([math.log(2 * math.pi)]).to(z) 146 | inv_sigma = torch.exp(-log_std) 147 | tmp = (z - mean) * inv_sigma 148 | return -0.5 * (tmp * tmp + 2 * log_std + c) 149 | 150 | 151 | def count_parameters(model): 152 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 153 | 154 | 155 | def reduce_bits(x): 156 | if args.nbits < 8: 157 | x = x * 255 158 | x = torch.floor(x / 2**(8 - args.nbits)) 159 | x = x / 2**args.nbits 160 | return x 161 | 162 | 163 | def add_noise(x, nvals=256): 164 | """ 165 | [0, 1] -> [0, nvals] -> add noise -> [0, 1] 166 | """ 167 | if args.add_noise: 168 | noise = x.new().resize_as_(x).uniform_() 169 | x = x * (nvals - 1) + noise 170 | x = x / nvals 171 | return x 172 | 173 | 174 | def update_lr(optimizer, itr): 175 | iter_frac = min(float(itr + 1) / max(args.warmup_iters, 1), 1.0) 176 | lr = args.lr * iter_frac 177 | for param_group in optimizer.param_groups: 178 | param_group["lr"] = lr 179 | 180 | 181 | def add_padding(x, nvals=256): 182 | # Theoretically, padding should've been added before the add_noise preprocessing. 183 | # nvals takes into account the preprocessing before padding is added. 
184 | if args.padding > 0: 185 | if args.padding_dist == 'uniform': 186 | u = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).uniform_() 187 | logpu = torch.zeros_like(u).sum([1, 2, 3]).view(-1, 1) 188 | return torch.cat([x, u / nvals], dim=1), logpu 189 | elif args.padding_dist == 'gaussian': 190 | u = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).normal_(nvals / 2, nvals / 8) 191 | logpu = normal_logprob(u, nvals / 2, math.log(nvals / 8)).sum([1, 2, 3]).view(-1, 1) 192 | return torch.cat([x, u / nvals], dim=1), logpu 193 | else: 194 | raise ValueError() 195 | else: 196 | return x, torch.zeros(x.shape[0], 1).to(x) 197 | 198 | 199 | def remove_padding(x): 200 | if args.padding > 0: 201 | return x[:, :im_dim, :, :] 202 | else: 203 | return x 204 | 205 | 206 | logger.info('Loading dataset {}'.format(args.data)) 207 | # Dataset and hyperparameters 208 | if args.data == 'cifar10': 209 | im_dim = 3 210 | n_classes = 10 211 | if args.task in ['classification', 'hybrid']: 212 | 213 | # Classification-specific preprocessing. 214 | transform_train = transforms.Compose([ 215 | transforms.Resize(args.imagesize), 216 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode), 217 | transforms.RandomHorizontalFlip(), 218 | transforms.ToTensor(), 219 | add_noise, 220 | ]) 221 | 222 | transform_test = transforms.Compose([ 223 | transforms.Resize(args.imagesize), 224 | transforms.ToTensor(), 225 | add_noise, 226 | ]) 227 | 228 | # Remove the logit transform. 
229 | init_layer = layers.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 230 | else: 231 | transform_train = transforms.Compose([ 232 | transforms.Resize(args.imagesize), 233 | transforms.RandomHorizontalFlip(), 234 | transforms.ToTensor(), 235 | add_noise, 236 | ]) 237 | transform_test = transforms.Compose([ 238 | transforms.Resize(args.imagesize), 239 | transforms.ToTensor(), 240 | add_noise, 241 | ]) 242 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None 243 | train_loader = torch.utils.data.DataLoader( 244 | datasets.CIFAR10(args.dataroot, train=True, transform=transform_train), 245 | batch_size=args.batchsize, 246 | shuffle=True, 247 | num_workers=args.nworkers, 248 | ) 249 | test_loader = torch.utils.data.DataLoader( 250 | datasets.CIFAR10(args.dataroot, train=False, transform=transform_test), 251 | batch_size=args.val_batchsize, 252 | shuffle=False, 253 | num_workers=args.nworkers, 254 | ) 255 | elif args.data == 'mnist': 256 | im_dim = 1 257 | init_layer = layers.LogitTransform(1e-6) if args.logit_transform else None 258 | n_classes = 10 259 | train_loader = torch.utils.data.DataLoader( 260 | datasets.MNIST( 261 | args.dataroot, train=True, transform=transforms.Compose([ 262 | transforms.Resize(args.imagesize), 263 | transforms.ToTensor(), 264 | add_noise, 265 | ]) 266 | ), 267 | batch_size=args.batchsize, 268 | shuffle=True, 269 | num_workers=args.nworkers, 270 | ) 271 | test_loader = torch.utils.data.DataLoader( 272 | datasets.MNIST( 273 | args.dataroot, train=False, transform=transforms.Compose([ 274 | transforms.Resize(args.imagesize), 275 | transforms.ToTensor(), 276 | add_noise, 277 | ]) 278 | ), 279 | batch_size=args.val_batchsize, 280 | shuffle=False, 281 | num_workers=args.nworkers, 282 | ) 283 | elif args.data == 'svhn': 284 | im_dim = 3 285 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None 286 | n_classes = 10 287 | train_loader = torch.utils.data.DataLoader( 288 | vdsets.SVHN( 
289 | args.dataroot, split='train', download=True, transform=transforms.Compose([ 290 | transforms.Resize(args.imagesize), 291 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode), 292 | transforms.ToTensor(), 293 | add_noise, 294 | ]) 295 | ), 296 | batch_size=args.batchsize, 297 | shuffle=True, 298 | num_workers=args.nworkers, 299 | ) 300 | test_loader = torch.utils.data.DataLoader( 301 | vdsets.SVHN( 302 | args.dataroot, split='test', download=True, transform=transforms.Compose([ 303 | transforms.Resize(args.imagesize), 304 | transforms.ToTensor(), 305 | add_noise, 306 | ]) 307 | ), 308 | batch_size=args.val_batchsize, 309 | shuffle=False, 310 | num_workers=args.nworkers, 311 | ) 312 | elif args.data == 'celebahq': 313 | im_dim = 3 314 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None 315 | if args.imagesize != 256: 316 | logger.info('Changing image size to 256.') 317 | args.imagesize = 256 318 | train_loader = torch.utils.data.DataLoader( 319 | datasets.CelebAHQ( 320 | train=True, transform=transforms.Compose([ 321 | transforms.ToPILImage(), 322 | transforms.RandomHorizontalFlip(), 323 | transforms.ToTensor(), 324 | reduce_bits, 325 | lambda x: add_noise(x, nvals=2**args.nbits), 326 | ]) 327 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 328 | ) 329 | test_loader = torch.utils.data.DataLoader( 330 | datasets.CelebAHQ( 331 | train=False, transform=transforms.Compose([ 332 | reduce_bits, 333 | lambda x: add_noise(x, nvals=2**args.nbits), 334 | ]) 335 | ), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 336 | ) 337 | elif args.data == 'celeba_5bit': 338 | im_dim = 3 339 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None 340 | if args.imagesize != 64: 341 | logger.info('Changing image size to 64.') 342 | args.imagesize = 64 343 | train_loader = torch.utils.data.DataLoader( 344 | datasets.CelebA5bit( 345 | train=True, 
transform=transforms.Compose([ 346 | transforms.ToPILImage(), 347 | transforms.RandomHorizontalFlip(), 348 | transforms.ToTensor(), 349 | lambda x: add_noise(x, nvals=32), 350 | ]) 351 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 352 | ) 353 | test_loader = torch.utils.data.DataLoader( 354 | datasets.CelebA5bit(train=False, transform=transforms.Compose([ 355 | lambda x: add_noise(x, nvals=32), 356 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 357 | ) 358 | elif args.data == 'imagenet32': 359 | im_dim = 3 360 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None 361 | if args.imagesize != 32: 362 | logger.info('Changing image size to 32.') 363 | args.imagesize = 32 364 | train_loader = torch.utils.data.DataLoader( 365 | datasets.Imagenet32(train=True, transform=transforms.Compose([ 366 | add_noise, 367 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 368 | ) 369 | test_loader = torch.utils.data.DataLoader( 370 | datasets.Imagenet32(train=False, transform=transforms.Compose([ 371 | add_noise, 372 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 373 | ) 374 | elif args.data == 'imagenet64': 375 | im_dim = 3 376 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None 377 | if args.imagesize != 64: 378 | logger.info('Changing image size to 64.') 379 | args.imagesize = 64 380 | train_loader = torch.utils.data.DataLoader( 381 | datasets.Imagenet64(train=True, transform=transforms.Compose([ 382 | add_noise, 383 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers 384 | ) 385 | test_loader = torch.utils.data.DataLoader( 386 | datasets.Imagenet64(train=False, transform=transforms.Compose([ 387 | add_noise, 388 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers 389 | ) 390 | 391 | if args.task in ['classification', 'hybrid']: 392 | try: 393 | n_classes 394 | except 
NameError: 395 | raise ValueError('Cannot perform classification with {}'.format(args.data)) 396 | else: 397 | n_classes = 1 398 | 399 | logger.info('Dataset loaded.') 400 | logger.info('Creating model.') 401 | 402 | input_size = (args.batchsize, im_dim + args.padding, args.imagesize, args.imagesize) 403 | dataset_size = len(train_loader.dataset) 404 | 405 | if args.squeeze_first: 406 | input_size = (input_size[0], input_size[1] * 4, input_size[2] // 2, input_size[3] // 2) 407 | squeeze_layer = layers.SqueezeLayer(2) 408 | 409 | # Model 410 | model = ResidualFlow( 411 | input_size, 412 | n_blocks=list(map(int, args.nblocks.split('-'))), 413 | intermediate_dim=args.idim, 414 | factor_out=args.factor_out, 415 | quadratic=args.quadratic, 416 | init_layer=init_layer, 417 | actnorm=args.actnorm, 418 | fc_actnorm=args.fc_actnorm, 419 | batchnorm=args.batchnorm, 420 | dropout=args.dropout, 421 | fc=args.fc, 422 | coeff=args.coeff, 423 | vnorms=args.vnorms, 424 | n_lipschitz_iters=args.n_lipschitz_iters, 425 | sn_atol=args.sn_tol, 426 | sn_rtol=args.sn_tol, 427 | n_power_series=args.n_power_series, 428 | n_dist=args.n_dist, 429 | n_samples=args.n_samples, 430 | kernels=args.kernels, 431 | activation_fn=args.act, 432 | fc_end=args.fc_end, 433 | fc_idim=args.fc_idim, 434 | n_exact_terms=args.n_exact_terms, 435 | preact=args.preact, 436 | neumann_grad=args.neumann_grad, 437 | grad_in_forward=args.mem_eff, 438 | first_resblock=args.first_resblock, 439 | learn_p=args.learn_p, 440 | classification=args.task in ['classification', 'hybrid'], 441 | classification_hdim=args.cdim, 442 | n_classes=n_classes, 443 | block_type=args.block, 444 | ) 445 | 446 | model.to(device) 447 | ema = utils.ExponentialMovingAverage(model) 448 | 449 | 450 | def parallelize(model): 451 | return torch.nn.DataParallel(model) 452 | 453 | 454 | logger.info(model) 455 | logger.info('EMA: {}'.format(ema)) 456 | 457 | 458 | # Optimization 459 | def tensor_in(t, a): 460 | for a_ in a: 461 | if t is a_: 462 | 
return True 463 | return False 464 | 465 | 466 | scheduler = None 467 | 468 | if args.optimizer == 'adam': 469 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.wd) 470 | if args.scheduler: scheduler = CosineAnnealingWarmRestarts(optimizer, 20, T_mult=2, last_epoch=args.begin_epoch - 1) 471 | elif args.optimizer == 'adamax': 472 | optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.wd) 473 | elif args.optimizer == 'rmsprop': 474 | optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd) 475 | elif args.optimizer == 'sgd': 476 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd) 477 | if args.scheduler: 478 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 479 | optimizer, milestones=[60, 120, 160], gamma=0.2, last_epoch=args.begin_epoch - 1 480 | ) 481 | else: 482 | raise ValueError('Unknown optimizer {}'.format(args.optimizer)) 483 | 484 | best_test_bpd = math.inf 485 | if (args.resume is not None): 486 | logger.info('Resuming model from {}'.format(args.resume)) 487 | with torch.no_grad(): 488 | x = torch.rand(1, *input_size[1:]).to(device) 489 | model(x) 490 | checkpt = torch.load(args.resume) 491 | sd = {k: v for k, v in checkpt['state_dict'].items() if 'last_n_samples' not in k} 492 | state = model.state_dict() 493 | state.update(sd) 494 | model.load_state_dict(state, strict=True) 495 | ema.set(checkpt['ema']) 496 | if 'optimizer_state_dict' in checkpt: 497 | optimizer.load_state_dict(checkpt['optimizer_state_dict']) 498 | # Manually move optimizer state to GPU 499 | for state in optimizer.state.values(): 500 | for k, v in state.items(): 501 | if torch.is_tensor(v): 502 | state[k] = v.to(device) 503 | del checkpt 504 | del state 505 | 506 | logger.info(optimizer) 507 | 508 | fixed_z = standard_normal_sample([min(32, args.batchsize), 509 | (im_dim + args.padding) * args.imagesize * args.imagesize]).to(device) 
510 | 511 | criterion = torch.nn.CrossEntropyLoss() 512 | 513 | 514 | def compute_loss(x, model, beta=1.0): 515 | bits_per_dim, logits_tensor = torch.zeros(1).to(x), torch.zeros(n_classes).to(x) 516 | logpz, delta_logp = torch.zeros(1).to(x), torch.zeros(1).to(x) 517 | 518 | if args.data == 'celeba_5bit': 519 | nvals = 32 520 | elif args.data == 'celebahq': 521 | nvals = 2**args.nbits 522 | else: 523 | nvals = 256 524 | 525 | x, logpu = add_padding(x, nvals) 526 | 527 | if args.squeeze_first: 528 | x = squeeze_layer(x) 529 | 530 | if args.task == 'hybrid': 531 | z_logp, logits_tensor = model(x.view(-1, *input_size[1:]), 0, classify=True) 532 | z, delta_logp = z_logp 533 | elif args.task == 'density': 534 | z, delta_logp = model(x.view(-1, *input_size[1:]), 0) 535 | elif args.task == 'classification': 536 | z, logits_tensor = model(x.view(-1, *input_size[1:]), classify=True) 537 | 538 | if args.task in ['density', 'hybrid']: 539 | # log p(z) 540 | logpz = standard_normal_logprob(z).view(z.size(0), -1).sum(1, keepdim=True) 541 | 542 | # log p(x) 543 | logpx = logpz - beta * delta_logp - np.log(nvals) * ( 544 | args.imagesize * args.imagesize * (im_dim + args.padding) 545 | ) - logpu 546 | bits_per_dim = -torch.mean(logpx) / (args.imagesize * args.imagesize * im_dim) / np.log(2) 547 | 548 | logpz = torch.mean(logpz).detach() 549 | delta_logp = torch.mean(-delta_logp).detach() 550 | 551 | return bits_per_dim, logits_tensor, logpz, delta_logp 552 | 553 | 554 | def estimator_moments(model, baseline=0): 555 | avg_first_moment = 0. 556 | avg_second_moment = 0. 557 | for m in model.modules(): 558 | if isinstance(m, layers.iResBlock): 559 | avg_first_moment += m.last_firmom.item() 560 | avg_second_moment += m.last_secmom.item() 561 | return avg_first_moment, avg_second_moment 562 | 563 | 564 | def compute_p_grads(model): 565 | scales = 0. 
566 | nlayers = 0 567 | for m in model.modules(): 568 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 569 | scales = scales + m.compute_one_iter() 570 | nlayers += 1 571 | scales.mul(1 / nlayers).backward() 572 | for m in model.modules(): 573 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 574 | if m.domain.grad is not None and torch.isnan(m.domain.grad): 575 | m.domain.grad = None 576 | 577 | 578 | batch_time = utils.RunningAverageMeter(0.97) 579 | bpd_meter = utils.RunningAverageMeter(0.97) 580 | logpz_meter = utils.RunningAverageMeter(0.97) 581 | deltalogp_meter = utils.RunningAverageMeter(0.97) 582 | firmom_meter = utils.RunningAverageMeter(0.97) 583 | secmom_meter = utils.RunningAverageMeter(0.97) 584 | gnorm_meter = utils.RunningAverageMeter(0.97) 585 | ce_meter = utils.RunningAverageMeter(0.97) 586 | 587 | 588 | def train(epoch, model): 589 | 590 | model = parallelize(model) 591 | model.train() 592 | 593 | total = 0 594 | correct = 0 595 | 596 | end = time.time() 597 | 598 | for i, (x, y) in enumerate(train_loader): 599 | 600 | global_itr = epoch * len(train_loader) + i 601 | update_lr(optimizer, global_itr) 602 | 603 | # Training procedure: 604 | # for each sample x: 605 | # compute z = f(x) 606 | # maximize log p(x) = log p(z) + log |det df/dx| 607 | 608 | x = x.to(device) 609 | 610 | beta = min(1, global_itr / args.annealing_iters) if args.annealing_iters > 0 else 1.
611 | bpd, logits, logpz, neg_delta_logp = compute_loss(x, model, beta=beta) 612 | 613 | if args.task in ['density', 'hybrid']: 614 | firmom, secmom = estimator_moments(model) 615 | 616 | bpd_meter.update(bpd.item()) 617 | logpz_meter.update(logpz.item()) 618 | deltalogp_meter.update(neg_delta_logp.item()) 619 | firmom_meter.update(firmom) 620 | secmom_meter.update(secmom) 621 | 622 | if args.task in ['classification', 'hybrid']: 623 | y = y.to(device) 624 | crossent = criterion(logits, y) 625 | ce_meter.update(crossent.item()) 626 | 627 | # Compute accuracy. 628 | _, predicted = logits.max(1) 629 | total += y.size(0) 630 | correct += predicted.eq(y).sum().item() 631 | 632 | # compute gradient and do SGD step 633 | if args.task == 'density': 634 | loss = bpd 635 | elif args.task == 'classification': 636 | loss = crossent 637 | else: 638 | if not args.scale_dim: bpd = bpd * (args.imagesize * args.imagesize * im_dim) 639 | loss = bpd + crossent / np.log(2) # Change cross entropy from nats to bits. 640 | loss.backward() 641 | 642 | if global_itr % args.update_freq == args.update_freq - 1: 643 | 644 | if args.update_freq > 1: 645 | with torch.no_grad(): 646 | for p in model.parameters(): 647 | if p.grad is not None: 648 | p.grad /= args.update_freq 649 | 650 | grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.) 
651 | if args.learn_p: compute_p_grads(model) 652 | 653 | optimizer.step() 654 | optimizer.zero_grad() 655 | update_lipschitz(model) 656 | ema.apply() 657 | 658 | gnorm_meter.update(grad_norm) 659 | 660 | # measure elapsed time 661 | batch_time.update(time.time() - end) 662 | end = time.time() 663 | 664 | if i % args.print_freq == 0: 665 | s = ( 666 | 'Epoch: [{0}][{1}/{2}] | Time {batch_time.val:.3f} | ' 667 | 'GradNorm {gnorm_meter.avg:.2f}'.format( 668 | epoch, i, len(train_loader), batch_time=batch_time, gnorm_meter=gnorm_meter 669 | ) 670 | ) 671 | 672 | if args.task in ['density', 'hybrid']: 673 | s += ( 674 | ' | Bits/dim {bpd_meter.val:.4f}({bpd_meter.avg:.4f}) | ' 675 | 'Logpz {logpz_meter.avg:.0f} | ' 676 | '-DeltaLogp {deltalogp_meter.avg:.0f} | ' 677 | 'EstMoment ({firmom_meter.avg:.0f},{secmom_meter.avg:.0f})'.format( 678 | bpd_meter=bpd_meter, logpz_meter=logpz_meter, deltalogp_meter=deltalogp_meter, 679 | firmom_meter=firmom_meter, secmom_meter=secmom_meter 680 | ) 681 | ) 682 | 683 | if args.task in ['classification', 'hybrid']: 684 | s += ' | CE {ce_meter.avg:.4f} | Acc {0:.4f}'.format(100 * correct / total, ce_meter=ce_meter) 685 | 686 | logger.info(s) 687 | if i % args.vis_freq == 0: 688 | visualize(epoch, model, i, x) 689 | 690 | del x 691 | torch.cuda.empty_cache() 692 | gc.collect() 693 | 694 | 695 | def validate(epoch, model, ema=None): 696 | """ 697 | Evaluates the cross entropy between p_data and p_model. 
698 | """ 699 | bpd_meter = utils.AverageMeter() 700 | ce_meter = utils.AverageMeter() 701 | 702 | if ema is not None: 703 | ema.swap() 704 | 705 | update_lipschitz(model) 706 | 707 | model = parallelize(model) 708 | model.eval() 709 | 710 | correct = 0 711 | total = 0 712 | 713 | start = time.time() 714 | with torch.no_grad(): 715 | for i, (x, y) in enumerate(tqdm(test_loader)): 716 | x = x.to(device) 717 | bpd, logits, _, _ = compute_loss(x, model) 718 | bpd_meter.update(bpd.item(), x.size(0)) 719 | 720 | if args.task in ['classification', 'hybrid']: 721 | y = y.to(device) 722 | loss = criterion(logits, y) 723 | ce_meter.update(loss.item(), x.size(0)) 724 | _, predicted = logits.max(1) 725 | total += y.size(0) 726 | correct += predicted.eq(y).sum().item() 727 | val_time = time.time() - start 728 | 729 | if ema is not None: 730 | ema.swap() 731 | s = 'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {bpd_meter.avg:.4f}'.format(epoch, val_time, bpd_meter=bpd_meter) 732 | if args.task in ['classification', 'hybrid']: 733 | s += ' | CE {:.4f} | Acc {:.2f}'.format(ce_meter.avg, 100 * correct / total) 734 | logger.info(s) 735 | return bpd_meter.avg 736 | 737 | 738 | def visualize(epoch, model, itr, real_imgs): 739 | model.eval() 740 | utils.makedirs(os.path.join(args.save, 'imgs')) 741 | real_imgs = real_imgs[:32] 742 | _real_imgs = real_imgs 743 | 744 | if args.data == 'celeba_5bit': 745 | nvals = 32 746 | elif args.data == 'celebahq': 747 | nvals = 2**args.nbits 748 | else: 749 | nvals = 256 750 | 751 | with torch.no_grad(): 752 | # reconstructed real images 753 | real_imgs, _ = add_padding(real_imgs, nvals) 754 | if args.squeeze_first: real_imgs = squeeze_layer(real_imgs) 755 | recon_imgs = model(model(real_imgs.view(-1, *input_size[1:])), inverse=True).view(-1, *input_size[1:]) 756 | if args.squeeze_first: recon_imgs = squeeze_layer.inverse(recon_imgs) 757 | recon_imgs = remove_padding(recon_imgs) 758 | 759 | # random samples 760 | fake_imgs = model(fixed_z, 
inverse=True).view(-1, *input_size[1:]) 761 | if args.squeeze_first: fake_imgs = squeeze_layer.inverse(fake_imgs) 762 | fake_imgs = remove_padding(fake_imgs) 763 | 764 | fake_imgs = fake_imgs.view(-1, im_dim, args.imagesize, args.imagesize) 765 | recon_imgs = recon_imgs.view(-1, im_dim, args.imagesize, args.imagesize) 766 | imgs = torch.cat([_real_imgs, fake_imgs, recon_imgs], 0) 767 | 768 | filename = os.path.join(args.save, 'imgs', 'e{:03d}_i{:06d}.png'.format(epoch, itr)) 769 | save_image(imgs.cpu().float(), filename, nrow=16, padding=2) 770 | model.train() 771 | 772 | 773 | def get_lipschitz_constants(model): 774 | lipschitz_constants = [] 775 | for m in model.modules(): 776 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear): 777 | lipschitz_constants.append(m.scale) 778 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 779 | lipschitz_constants.append(m.scale) 780 | if isinstance(m, base_layers.LopConv2d) or isinstance(m, base_layers.LopLinear): 781 | lipschitz_constants.append(m.scale) 782 | return lipschitz_constants 783 | 784 | 785 | def update_lipschitz(model): 786 | with torch.no_grad(): 787 | for m in model.modules(): 788 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear): 789 | m.compute_weight(update=True) 790 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 791 | m.compute_weight(update=True) 792 | 793 | 794 | def get_ords(model): 795 | ords = [] 796 | for m in model.modules(): 797 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 798 | domain, codomain = m.compute_domain_codomain() 799 | if torch.is_tensor(domain): 800 | domain = domain.item() 801 | if torch.is_tensor(codomain): 802 | codomain = codomain.item() 803 | ords.append(domain) 804 | ords.append(codomain) 805 | return ords 806 | 807 | 808 | def 
pretty_repr(a): 809 | return '[[' + ','.join(list(map(lambda i: f'{i:.2f}', a))) + ']]' 810 | 811 | 812 | def main(): 813 | global best_test_bpd 814 | 815 | last_checkpoints = [] 816 | lipschitz_constants = [] 817 | ords = [] 818 | 819 | # if args.resume: 820 | # validate(args.begin_epoch - 1, model, ema) 821 | for epoch in range(args.begin_epoch, args.nepochs): 822 | 823 | logger.info('Current LR {}'.format(optimizer.param_groups[0]['lr'])) 824 | 825 | train(epoch, model) 826 | lipschitz_constants.append(get_lipschitz_constants(model)) 827 | ords.append(get_ords(model)) 828 | logger.info('Lipsh: {}'.format(pretty_repr(lipschitz_constants[-1]))) 829 | logger.info('Order: {}'.format(pretty_repr(ords[-1]))) 830 | 831 | if args.ema_val: 832 | test_bpd = validate(epoch, model, ema) 833 | else: 834 | test_bpd = validate(epoch, model) 835 | 836 | if args.scheduler and scheduler is not None: 837 | scheduler.step() 838 | 839 | if test_bpd < best_test_bpd: 840 | best_test_bpd = test_bpd 841 | utils.save_checkpoint({ 842 | 'state_dict': model.state_dict(), 843 | 'optimizer_state_dict': optimizer.state_dict(), 844 | 'args': args, 845 | 'ema': ema, 846 | 'test_bpd': test_bpd, 847 | }, os.path.join(args.save, 'models'), epoch, last_checkpoints, num_checkpoints=5) 848 | 849 | torch.save({ 850 | 'state_dict': model.state_dict(), 851 | 'optimizer_state_dict': optimizer.state_dict(), 852 | 'args': args, 853 | 'ema': ema, 854 | 'test_bpd': test_bpd, 855 | }, os.path.join(args.save, 'models', 'most_recent.pth')) 856 | 857 | 858 | if __name__ == '__main__': 859 | main() 860 | -------------------------------------------------------------------------------- /train_toy.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | 5 | import argparse 6 | import os 7 | import time 8 | import math 9 | import numpy as np 10 | 11 | import torch 12 | 13 | import resflows.optimizers as 
optim 14 | import resflows.layers.base as base_layers 15 | import resflows.layers as layers 16 | import resflows.toy_data as toy_data 17 | import resflows.utils as utils 18 | from resflows.visualize_flow import visualize_transform 19 | 20 | ACTIVATION_FNS = { 21 | 'relu': torch.nn.ReLU, 22 | 'tanh': torch.nn.Tanh, 23 | 'elu': torch.nn.ELU, 24 | 'selu': torch.nn.SELU, 25 | 'fullsort': base_layers.FullSort, 26 | 'maxmin': base_layers.MaxMin, 27 | 'swish': base_layers.Swish, 28 | 'lcube': base_layers.LipschitzCube, 29 | } 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], 34 | type=str, default='pinwheel' 35 | ) 36 | parser.add_argument('--arch', choices=['iresnet', 'realnvp'], default='iresnet') 37 | parser.add_argument('--coeff', type=float, default=0.9) 38 | parser.add_argument('--vnorms', type=str, default='222222') 39 | parser.add_argument('--n-lipschitz-iters', type=int, default=5) 40 | parser.add_argument('--atol', type=float, default=None) 41 | parser.add_argument('--rtol', type=float, default=None) 42 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False) 43 | parser.add_argument('--mixed', type=eval, choices=[True, False], default=True) 44 | 45 | parser.add_argument('--dims', type=str, default='128-128-128-128') 46 | parser.add_argument('--act', type=str, choices=ACTIVATION_FNS.keys(), default='swish') 47 | parser.add_argument('--nblocks', type=int, default=100) 48 | parser.add_argument('--brute-force', type=eval, choices=[True, False], default=False) 49 | parser.add_argument('--actnorm', type=eval, choices=[True, False], default=False) 50 | parser.add_argument('--batchnorm', type=eval, choices=[True, False], default=False) 51 | parser.add_argument('--exact-trace', type=eval, choices=[True, False], default=False) 52 | parser.add_argument('--n-power-series', type=int, default=None) 53 | 
parser.add_argument('--n-samples', type=int, default=1) 54 | parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric') 55 | 56 | parser.add_argument('--niters', type=int, default=50000) 57 | parser.add_argument('--batch_size', type=int, default=500) 58 | parser.add_argument('--test_batch_size', type=int, default=10000) 59 | parser.add_argument('--lr', type=float, default=1e-3) 60 | parser.add_argument('--weight-decay', type=float, default=1e-5) 61 | parser.add_argument('--annealing-iters', type=int, default=0) 62 | 63 | parser.add_argument('--save', type=str, default='experiments/iresnet_toy') 64 | parser.add_argument('--viz_freq', type=int, default=100) 65 | parser.add_argument('--val_freq', type=int, default=100) 66 | parser.add_argument('--log_freq', type=int, default=10) 67 | parser.add_argument('--gpu', type=int, default=0) 68 | parser.add_argument('--seed', type=int, default=0) 69 | args = parser.parse_args() 70 | 71 | # logger 72 | utils.makedirs(args.save) 73 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) 74 | logger.info(args) 75 | 76 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 77 | 78 | np.random.seed(args.seed) 79 | torch.manual_seed(args.seed) 80 | if device.type == 'cuda': 81 | torch.cuda.manual_seed(args.seed) 82 | 83 | 84 | def count_parameters(model): 85 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 86 | 87 | 88 | def standard_normal_sample(size): 89 | return torch.randn(size) 90 | 91 | 92 | def standard_normal_logprob(z): 93 | logZ = -0.5 * math.log(2 * math.pi) 94 | return logZ - z.pow(2) / 2 95 | 96 | 97 | def compute_loss(args, model, batch_size=None, beta=1.): 98 | if batch_size is None: batch_size = args.batch_size 99 | 100 | # load data 101 | x = toy_data.inf_train_gen(args.data, batch_size=batch_size) 102 | x = torch.from_numpy(x).type(torch.float32).to(device) 103 | zero = 
torch.zeros(x.shape[0], 1).to(x) 104 | 105 | # transform to z 106 | z, delta_logp = model(x, zero) 107 | 108 | # compute log p(z) 109 | logpz = standard_normal_logprob(z).sum(1, keepdim=True) 110 | 111 | logpx = logpz - beta * delta_logp 112 | loss = -torch.mean(logpx) 113 | return loss, torch.mean(logpz), torch.mean(-delta_logp) 114 | 115 | 116 | def parse_vnorms(): 117 | ps = [] 118 | for p in args.vnorms: 119 | if p == 'f': 120 | ps.append(float('inf')) 121 | else: 122 | ps.append(float(p)) 123 | return ps[:-1], ps[1:] 124 | 125 | 126 | def compute_p_grads(model): 127 | scales = 0. 128 | nlayers = 0 129 | for m in model.modules(): 130 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 131 | scales = scales + m.compute_one_iter() 132 | nlayers += 1 133 | scales.mul(1 / nlayers).mul(0.01).backward() 134 | for m in model.modules(): 135 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 136 | if m.domain.grad is not None and torch.isnan(m.domain.grad): 137 | m.domain.grad = None 138 | 139 | 140 | def build_nnet(dims, activation_fn=torch.nn.ReLU): 141 | nnet = [] 142 | domains, codomains = parse_vnorms() 143 | if args.learn_p: 144 | if args.mixed: 145 | domains = [torch.nn.Parameter(torch.tensor(0.)) for _ in domains] 146 | else: 147 | domains = [torch.nn.Parameter(torch.tensor(0.))] * len(domains) 148 | codomains = domains[1:] + [domains[0]] 149 | for i, (in_dim, out_dim, domain, codomain) in enumerate(zip(dims[:-1], dims[1:], domains, codomains)): 150 | nnet.append(activation_fn()) 151 | nnet.append( 152 | base_layers.get_linear( 153 | in_dim, 154 | out_dim, 155 | coeff=args.coeff, 156 | n_iterations=args.n_lipschitz_iters, 157 | atol=args.atol, 158 | rtol=args.rtol, 159 | domain=domain, 160 | codomain=codomain, 161 | zero_init=(out_dim == 2), 162 | ) 163 | ) 164 | return torch.nn.Sequential(*nnet) 165 | 166 | 167 | def update_lipschitz(model, n_iterations): 168 | for m 
in model.modules(): 169 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear): 170 | m.compute_weight(update=True, n_iterations=n_iterations) 171 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 172 | m.compute_weight(update=True, n_iterations=n_iterations) 173 | 174 | 175 | def get_ords(model): 176 | ords = [] 177 | for m in model.modules(): 178 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear): 179 | domain, codomain = m.compute_domain_codomain() 180 | if torch.is_tensor(domain): 181 | domain = domain.item() 182 | if torch.is_tensor(codomain): 183 | codomain = codomain.item() 184 | ords.append(domain) 185 | ords.append(codomain) 186 | return ords 187 | 188 | 189 | def pretty_repr(a): 190 | return '[[' + ','.join(list(map(lambda i: f'{i:.2f}', a))) + ']]' 191 | 192 | 193 | if __name__ == '__main__': 194 | 195 | activation_fn = ACTIVATION_FNS[args.act] 196 | 197 | if args.arch == 'iresnet': 198 | dims = [2] + list(map(int, args.dims.split('-'))) + [2] 199 | blocks = [] 200 | if args.actnorm: blocks.append(layers.ActNorm1d(2)) 201 | for _ in range(args.nblocks): 202 | blocks.append( 203 | layers.iResBlock( 204 | build_nnet(dims, activation_fn), 205 | n_dist=args.n_dist, 206 | n_power_series=args.n_power_series, 207 | exact_trace=args.exact_trace, 208 | brute_force=args.brute_force, 209 | n_samples=args.n_samples, 210 | neumann_grad=False, 211 | grad_in_forward=False, 212 | ) 213 | ) 214 | if args.actnorm: blocks.append(layers.ActNorm1d(2)) 215 | if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2)) 216 | model = layers.SequentialFlow(blocks).to(device) 217 | elif args.arch == 'realnvp': 218 | blocks = [] 219 | for _ in range(args.nblocks): 220 | blocks.append(layers.CouplingBlock(2, swap=False)) 221 | blocks.append(layers.CouplingBlock(2, swap=True)) 222 | if args.actnorm: blocks.append(layers.ActNorm1d(2)) 223 | if 
args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2)) 224 | model = layers.SequentialFlow(blocks).to(device) 225 | 226 | logger.info(model) 227 | logger.info("Number of trainable parameters: {}".format(count_parameters(model))) 228 | 229 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 230 | 231 | time_meter = utils.RunningAverageMeter(0.93) 232 | loss_meter = utils.RunningAverageMeter(0.93) 233 | logpz_meter = utils.RunningAverageMeter(0.93) 234 | delta_logp_meter = utils.RunningAverageMeter(0.93) 235 | 236 | end = time.time() 237 | best_loss = float('inf') 238 | model.train() 239 | for itr in range(1, args.niters + 1): 240 | optimizer.zero_grad() 241 | 242 | beta = min(1, itr / args.annealing_iters) if args.annealing_iters > 0 else 1. 243 | loss, logpz, delta_logp = compute_loss(args, model, beta=beta) 244 | loss_meter.update(loss.item()) 245 | logpz_meter.update(logpz.item()) 246 | delta_logp_meter.update(delta_logp.item()) 247 | loss.backward() 248 | if args.learn_p and itr > args.annealing_iters: compute_p_grads(model) 249 | optimizer.step() 250 | update_lipschitz(model, args.n_lipschitz_iters) 251 | 252 | time_meter.update(time.time() - end) 253 | 254 | logger.info( 255 | 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f})' 256 | ' | Logp(z) {:.6f}({:.6f}) | DeltaLogp {:.6f}({:.6f})'.format( 257 | itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, logpz_meter.val, logpz_meter.avg, 258 | delta_logp_meter.val, delta_logp_meter.avg 259 | ) 260 | ) 261 | 262 | if itr % args.val_freq == 0 or itr == args.niters: 263 | update_lipschitz(model, 200) 264 | with torch.no_grad(): 265 | model.eval() 266 | test_loss, test_logpz, test_delta_logp = compute_loss(args, model, batch_size=args.test_batch_size) 267 | log_message = ( 268 | '[TEST] Iter {:04d} | Test Loss {:.6f} ' 269 | '| Test Logp(z) {:.6f} | Test DeltaLogp {:.6f}'.format( 270 | itr, test_loss.item(), test_logpz.item(), test_delta_logp.item() 271 
| ) 272 | ) 273 | logger.info(log_message) 274 | 275 | logger.info('Ords: {}'.format(pretty_repr(get_ords(model)))) 276 | 277 | if test_loss.item() < best_loss: 278 | best_loss = test_loss.item() 279 | utils.makedirs(args.save) 280 | torch.save({ 281 | 'args': args, 282 | 'state_dict': model.state_dict(), 283 | }, os.path.join(args.save, 'checkpt.pth')) 284 | model.train() 285 | 286 | if itr == 1 or itr % args.viz_freq == 0: 287 | with torch.no_grad(): 288 | model.eval() 289 | p_samples = toy_data.inf_train_gen(args.data, batch_size=20000) 290 | 291 | sample_fn, density_fn = model.inverse, model.forward 292 | 293 | plt.figure(figsize=(9, 3)) 294 | visualize_transform( 295 | p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, 296 | samples=True, npts=400, device=device 297 | ) 298 | fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) 299 | utils.makedirs(os.path.dirname(fig_filename)) 300 | plt.savefig(fig_filename) 301 | plt.close() 302 | model.train() 303 | 304 | end = time.time() 305 | 306 | logger.info('Training has finished.') 307 | --------------------------------------------------------------------------------
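For reference, the density branch of `compute_loss` in train_img.py turns a flow log-likelihood into bits per dimension. Because the pixels are modeled in [0, 1] (after `ToTensor` and `add_noise`) while the discrete data have `nvals` levels, it subtracts `log(nvals)` per dimension, then divides by the number of image dimensions and by `log(2)`. The sketch below reproduces only that arithmetic in plain Python for a single example. It assumes no channel padding, so the `logpu` term from `add_padding` is dropped and both dimension counts coincide, and it takes `delta_logp` with the script's sign convention, where it is subtracted from `log p(z)`. The helper names `bits_per_dim` and `select_nvals` are hypothetical.

```python
import math


def standard_normal_logprob(z):
    # Per-coordinate log N(z; 0, 1), the same formula as in train_toy.py.
    return -0.5 * math.log(2 * math.pi) - z * z / 2


def select_nvals(data, nbits=8):
    # Number of discrete levels per pixel, following the branch in compute_loss.
    if data == 'celeba_5bit':
        return 32
    if data == 'celebahq':
        return 2 ** nbits
    return 256


def bits_per_dim(zs, delta_logp, nvals, beta=1.0):
    """Bits/dim for one example, assuming no channel padding (hypothetical helper).

    zs: latent coordinates z = f(x), one float per data dimension.
    delta_logp: the flow's log-density change, subtracted as in compute_loss.
    """
    num_dims = len(zs)
    logpz = sum(standard_normal_logprob(z) for z in zs)
    # Account for mapping nvals discrete levels into [0, 1]: log(nvals) nats per dimension.
    logpx = logpz - beta * delta_logp - math.log(nvals) * num_dims
    return -logpx / num_dims / math.log(2)
```

With `z = 0` in every coordinate and no volume change, each dimension costs `log2(nvals)` bits plus the Gaussian normalizer `0.5 * log2(2 * pi)`, roughly 9.33 bits/dim for 8-bit data.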
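`parse_vnorms` in train_toy.py reads the global `args.vnorms`. The standalone sketch below takes the string as an argument instead (an illustrative change, not the script's signature) to show how one character per norm order becomes per-layer domain and codomain lists, with `'f'` standing for the infinity norm.

```python
def parse_vnorms(vnorms):
    """Turn a norm string like '222222' into per-layer (domains, codomains) lists.

    Each character is a vector-norm order p; 'f' stands for the infinity norm.
    Character k is the input (domain) norm of layer k and character k + 1 its
    output (codomain) norm, so a string of length L + 1 describes L layers.
    """
    ps = [float('inf') if p == 'f' else float(p) for p in vnorms]
    return ps[:-1], ps[1:]
```

With the defaults `--dims 128-128-128-128` and `--vnorms 222222`, `build_nnet` creates five linear layers (sizes 2, 128, 128, 128, 128, 2), and the six-character string yields exactly five (domain, codomain) pairs. Because `build_nnet` combines these lists with `zip`, a string that is too short would silently build fewer layers rather than raise an error.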
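Two pieces of control flow in `train()` are easy to misread. The weight `beta` on the log-determinant term is warmed up linearly over `args.annealing_iters` iterations. When `args.update_freq` is greater than 1, the gradients of several mini-batches are summed by repeated `loss.backward()` calls, divided by `update_freq`, and applied in a single optimizer step. The pure-Python sketch below mimics both with scalar "gradients". The function names are hypothetical, and the sketch leaves out gradient clipping and the Lipschitz and EMA updates that follow each step.

```python
def annealed_beta(global_itr, annealing_iters):
    # Same expression as in train(): linear ramp from 0 to 1, or 1 if annealing is off.
    return min(1, global_itr / annealing_iters) if annealing_iters > 0 else 1.


def is_update_step(global_itr, update_freq):
    # train() steps the optimizer on the last iteration of every group of update_freq.
    return global_itr % update_freq == update_freq - 1


def accumulate(grads_per_itr, update_freq):
    """Return the averaged gradient applied at each optimizer step (one scalar parameter)."""
    applied, acc = [], 0.0
    for itr, g in enumerate(grads_per_itr):
        acc += g  # loss.backward() adds into p.grad
        if is_update_step(itr, update_freq):
            applied.append(acc / update_freq)  # p.grad /= update_freq, then optimizer.step()
            acc = 0.0  # optimizer.zero_grad()
    return applied
```

Because `global_itr` is computed as `epoch * len(train_loader) + i`, both the annealing ramp and the accumulation groups continue across epoch boundaries instead of restarting each epoch.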
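`main()` calls `scheduler.step()` once per epoch, so when `args.scheduler` is set the learning rate follows an epoch-level schedule. Adam uses `CosineAnnealingWarmRestarts(optimizer, 20, T_mult=2)`, and SGD uses `MultiStepLR` with milestones 60, 120 and 160 and `gamma=0.2`. The sketch below computes both schedules in closed form. It assumes the cosine scheduler, which is imported outside this excerpt, follows the same SGDR formula as PyTorch's `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts` with `eta_min=0`. It also ignores the per-iteration `update_lr` call, whose definition is not shown here. The function names are hypothetical.

```python
import math


def cosine_warm_restarts_lr(epoch, base_lr, t_0=20, t_mult=2, eta_min=0.0):
    """Learning rate after `epoch` scheduler steps (SGDR cosine annealing with restarts)."""
    t_i, t_cur = t_0, epoch
    # Skip over completed cycles; each new cycle is t_mult times longer than the last.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2


def multistep_lr(epoch, base_lr, milestones=(60, 120, 160), gamma=0.2):
    """Learning rate after `epoch` steps of a MultiStepLR-style schedule."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed
```

Under these assumptions the cosine cycles last 20, 40 and then 80 epochs, so the rate returns to its base value at epochs 20, 60 and 140.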