├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── celebahq_resflow.jpg
│   └── flow_comparison.jpg
├── preprocessing
│   ├── convert_to_pth.py
│   ├── create_imagenet_benchmark_datasets.py
│   └── extract_celeba_from_tfrecords.py
├── qualitative_samples.py
├── resflows
│   ├── datasets.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── act_norm.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── lipschitz.py
│   │   │   ├── mixed_lipschitz.py
│   │   │   └── utils.py
│   │   ├── container.py
│   │   ├── coupling.py
│   │   ├── elemwise.py
│   │   ├── glow.py
│   │   ├── iresblock.py
│   │   ├── mask_utils.py
│   │   ├── normalization.py
│   │   └── squeeze.py
│   ├── lr_scheduler.py
│   ├── optimizers.py
│   ├── resflow.py
│   ├── toy_data.py
│   ├── utils.py
│   └── visualize_flow.py
├── setup.py
├── train_img.py
└── train_toy.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *__pycache__*
3 | data/*
4 | pretrained_models
5 | *egg-info
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Ricky Tian Qi Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Residual Flows for Invertible Generative Modeling [[arxiv](https://arxiv.org/abs/1906.02735)]
2 |
3 | ![Comparison of residual flows with other flow-based models](./assets/flow_comparison.jpg)
4 |
5 |
6 |
7 | Building on the use of [Invertible Residual Networks](https://arxiv.org/abs/1811.00995) in generative modeling, we propose:
8 | + Unbiased estimation of the log-density of samples.
9 | + Memory-efficient reformulation of the gradients.
10 | + LipSwish activation function.
11 |
12 | As a result, Residual Flows scale to much larger networks and datasets.
13 |
14 | ![Samples from a Residual Flow trained on CelebA-HQ 256x256](./assets/celebahq_resflow.jpg)
15 |
16 |
17 |
18 | ## Requirements
19 |
20 | - PyTorch 1.0+
21 | - Python 3.6+
22 |
23 | ## Package installation (optional)
24 |
25 | If you just want to use the `resflows` package, you can install it through pip:
26 | ```
27 | pip install git+https://github.com/rtqichen/residual-flows
28 | ```
29 |
30 | ## Preprocessing
31 | ImageNet:
32 | 1. Follow the instructions in `preprocessing/create_imagenet_benchmark_datasets.py`.
33 | 2. Convert the resulting .npy files to .pth using `preprocessing/convert_to_pth.py`.
34 | 3. Place them in `data/imagenet32` and `data/imagenet64`.
35 |
36 | CelebAHQ 64x64 5bit:
37 |
38 | 1. Download from https://github.com/aravindsrinivas/flowpp/tree/master/flows_celeba.
39 | 2. Convert the .npy files to .pth using `preprocessing/convert_to_pth.py`.
40 | 3. Place in `data/celebahq64_5bit`.
41 |
42 | CelebAHQ 256x256:
43 | ```
44 | # Download Glow's preprocessed dataset.
45 | wget https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar
46 | mkdir -p data/celebahq && tar -C data/celebahq -xvf celeba-tfr.tar
47 | python preprocessing/extract_celeba_from_tfrecords.py
48 | ```
49 |
50 | ## Density Estimation Experiments
51 |
52 | ***NOTE***: By default, O(1)-memory gradients are enabled. However, the logged bits/dim during training will not be an actual estimate of bits/dim but whatever scalar was used to generate the unbiased gradients. If you want to check the actual bits/dim for training (and have sufficient GPU memory), set `--neumann-grad=False`. Note however that the memory cost can stochastically vary during training if this flag is `False`.
53 |
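54 | For example, to log exact bits/dim on CIFAR10 (the remaining flags mirror the CIFAR10 command below):
55 | ```
56 | python train_img.py --data cifar10 --actnorm True --neumann-grad False --save experiments/cifar10_exact
57 | ```
58 |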
54 | MNIST:
55 | ```
56 | python train_img.py --data mnist --imagesize 28 --actnorm True --wd 0 --save experiments/mnist
57 | ```
58 |
59 | CIFAR10:
60 | ```
61 | python train_img.py --data cifar10 --actnorm True --save experiments/cifar10
62 | ```
63 |
64 | ImageNet 32x32:
65 | ```
66 | python train_img.py --data imagenet32 --actnorm True --nblocks 32-32-32 --save experiments/imagenet32
67 | ```
68 |
69 | ImageNet 64x64:
70 | ```
71 | python train_img.py --data imagenet64 --imagesize 64 --actnorm True --nblocks 32-32-32 --factor-out True --squeeze-first True --save experiments/imagenet64
72 | ```
73 |
74 | CelebAHQ 256x256:
75 | ```
76 | python train_img.py --data celebahq --imagesize 256 --nbits 5 --actnorm True --act elu --batchsize 8 --update-freq 5 --n-exact-terms 8 --fc-end False --factor-out True --squeeze-first True --nblocks 16-16-16-16-16-16 --save experiments/celebahq256
77 | ```
78 |
79 | ## Pretrained Models
80 |
81 | Model checkpoints can be downloaded from [releases](https://github.com/rtqichen/residual-flows/releases/latest).
82 |
83 | Use the argument `--resume [checkpt.pth]` to evaluate or sample from the model.
84 |
85 | Each checkpoint contains two sets of parameters, one from training and one containing the exponential moving average (EMA) accumulated over the course of training. Scripts will automatically use the EMA parameters for evaluation and sampling.
86 |
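87 | For reference, the scripts load a checkpoint and swap in the EMA parameters roughly as follows (a sketch adapted from `qualitative_samples.py`, where `model` is the constructed `ResidualFlow`):
88 | ```
89 | checkpt = torch.load('checkpt.pth')
90 | model.load_state_dict(checkpt['state_dict'], strict=True)
91 | ema = utils.ExponentialMovingAverage(model)
92 | ema.set(checkpt['ema'])
93 | ema.swap()  # model now holds the EMA parameters
94 | ```
95 |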
87 | ## BibTeX
88 | ```
89 | @inproceedings{chen2019residualflows,
90 | title={Residual Flows for Invertible Generative Modeling},
91 | author={Chen, Ricky T. Q. and Behrmann, Jens and Duvenaud, David and Jacobsen, J{\"{o}}rn{-}Henrik},
92 | booktitle = {Advances in Neural Information Processing Systems},
93 | year={2019}
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/assets/celebahq_resflow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtqichen/residual-flows/8170138c850a3574319491d97093bc860ce4922d/assets/celebahq_resflow.jpg
--------------------------------------------------------------------------------
/assets/flow_comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtqichen/residual-flows/8170138c850a3574319491d97093bc860ce4922d/assets/flow_comparison.jpg
--------------------------------------------------------------------------------
/preprocessing/convert_to_pth.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import re
3 | import numpy as np
4 | import torch
5 |
6 | img = torch.tensor(np.load(sys.argv[1]))  # load the (N, H, W, C) uint8 image array
7 | img = img.permute(0, 3, 1, 2)  # convert to (N, C, H, W) for PyTorch
8 | torch.save(img, re.sub(r'\.npy$', '.pth', sys.argv[1]))  # raw string so '.' matches literally
9 |
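10 | # Example usage (hypothetical path; pass the .npy file as the only argument):
11 | #   python preprocessing/convert_to_pth.py data/imagenet32/train_32x32.npy
12 | # This writes data/imagenet32/train_32x32.pth next to the input file.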
--------------------------------------------------------------------------------
/preprocessing/create_imagenet_benchmark_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Run the following commands in ~ before running this file
3 | wget http://image-net.org/small/train_64x64.tar
4 | wget http://image-net.org/small/valid_64x64.tar
5 | tar -xvf train_64x64.tar
6 | tar -xvf valid_64x64.tar
7 | wget http://image-net.org/small/train_32x32.tar
8 | wget http://image-net.org/small/valid_32x32.tar
9 | tar -xvf train_32x32.tar
10 | tar -xvf valid_32x32.tar
11 | """
12 |
13 | import numpy as np
14 | import imageio  # scipy.ndimage.imread was removed in SciPy 1.2; imageio.imread is a drop-in replacement here
15 | import os
16 | from os import listdir
17 | from os.path import isfile, join
18 | from tqdm import tqdm
19 |
20 |
21 | def convert_path_to_npy(*, path='train_64x64', outfile='train_64x64.npy'):
22 | assert isinstance(path, str), "Expected a string input for the path"
23 | assert os.path.exists(path), "Input path doesn't exist"
24 | files = [f for f in listdir(path) if isfile(join(path, f))]
25 | print('Number of valid images is:', len(files))
26 | imgs = []
27 | for i in tqdm(range(len(files))):
28 |         img = imageio.imread(join(path, files[i]))
29 | img = img.astype('uint8')
30 | assert np.max(img) <= 255
31 | assert np.min(img) >= 0
32 | assert img.dtype == 'uint8'
33 | assert isinstance(img, np.ndarray)
34 | imgs.append(img)
35 | resolution_x, resolution_y = img.shape[0], img.shape[1]
36 | imgs = np.asarray(imgs).astype('uint8')
37 | assert imgs.shape[1:] == (resolution_x, resolution_y, 3)
38 | assert np.max(imgs) <= 255
39 | assert np.min(imgs) >= 0
40 | print('Total number of images is:', imgs.shape[0])
41 | print('All assertions done, dumping into npy file')
42 | np.save(outfile, imgs)
43 |
44 |
45 | if __name__ == '__main__':
46 | convert_path_to_npy(path='train_64x64', outfile='train_64x64.npy')
47 | convert_path_to_npy(path='valid_64x64', outfile='valid_64x64.npy')
48 | convert_path_to_npy(path='train_32x32', outfile='train_32x32.npy')
49 | convert_path_to_npy(path='valid_32x32', outfile='valid_32x32.npy')
50 |
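51 | # Each output .npy holds a uint8 array of shape (N, H, W, 3); per the README, convert
52 | # these to .pth tensors with preprocessing/convert_to_pth.py before training.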
--------------------------------------------------------------------------------
/preprocessing/extract_celeba_from_tfrecords.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf  # requires TensorFlow 1.x (tf.python_io and tf.InteractiveSession)
3 | import torch
4 |
5 | sess = tf.InteractiveSession()
6 |
7 | train_imgs = []
8 |
9 | print('Reading from training set...', flush=True)
10 | for i in range(120):
11 | tfr = 'data/celebahq/celeba-tfr/train/train-r08-s-{:04d}-of-0120.tfrecords'.format(i)
12 | print(tfr, flush=True)
13 |
14 | record_iterator = tf.python_io.tf_record_iterator(tfr)
15 |
16 | for string_record in record_iterator:
17 | example = tf.train.Example()
18 | example.ParseFromString(string_record)
19 |
20 | image_bytes = example.features.feature['data'].bytes_list.value[0]
21 |
22 | img = tf.decode_raw(image_bytes, tf.uint8)
23 | img = tf.reshape(img, [256, 256, 3])
24 | img = img.eval()
25 |
26 | train_imgs.append(img)
27 |
28 | train_imgs = np.stack(train_imgs)
29 | train_imgs = torch.tensor(train_imgs).permute(0, 3, 1, 2)
30 | torch.save(train_imgs, 'data/celebahq/celeba256_train.pth')
31 |
32 | validation_imgs = []
33 | for i in range(40):
34 | tfr = 'data/celebahq/celeba-tfr/validation/validation-r08-s-{:04d}-of-0040.tfrecords'.format(i)
35 | print(tfr, flush=True)
36 |
37 | record_iterator = tf.python_io.tf_record_iterator(tfr)
38 |
39 | for string_record in record_iterator:
40 | example = tf.train.Example()
41 | example.ParseFromString(string_record)
42 |
43 | image_bytes = example.features.feature['data'].bytes_list.value[0]
44 |
45 | img = tf.decode_raw(image_bytes, tf.uint8)
46 | img = tf.reshape(img, [256, 256, 3])
47 | img = img.eval()
48 |
49 | validation_imgs.append(img)
50 |
51 | validation_imgs = np.stack(validation_imgs)
52 | validation_imgs = torch.tensor(validation_imgs).permute(0, 3, 1, 2)
53 | torch.save(validation_imgs, 'data/celebahq/celeba256_validation.pth')
54 |
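55 | # The saved tensors are uint8 with shape (N, 3, 256, 256); resflows.datasets.CelebAHQ
56 | # loads them and rescales to [0, 1] by dividing by 255.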
--------------------------------------------------------------------------------
/qualitative_samples.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | from tqdm import tqdm
4 |
5 | import torch
6 | import torchvision.transforms as transforms
7 | from torchvision.utils import save_image
8 | import torchvision.datasets as vdsets
9 |
10 | from resflows.resflow import ACT_FNS, ResidualFlow
11 | import resflows.datasets as datasets
12 | import resflows.utils as utils
13 | import resflows.layers as layers
14 | import resflows.layers.base as base_layers
15 |
16 | # Arguments
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | '--data', type=str, default='cifar10', choices=[
20 | 'mnist',
21 | 'cifar10',
22 | 'celeba',
23 | 'celebahq',
24 | 'celeba_5bit',
25 | 'imagenet32',
26 | 'imagenet64',
27 | ]
28 | )
29 | parser.add_argument('--dataroot', type=str, default='data')
30 | parser.add_argument('--imagesize', type=int, default=32)
31 | parser.add_argument('--nbits', type=int, default=8) # Only used for celebahq.
32 |
33 | # Sampling parameters.
34 | parser.add_argument('--real', type=eval, choices=[True, False], default=False)
35 | parser.add_argument('--nrow', type=int, default=10)
36 | parser.add_argument('--ncol', type=int, default=10)
37 | parser.add_argument('--temp', type=float, default=1.0)
38 | parser.add_argument('--nbatches', type=int, default=5)
39 | parser.add_argument('--save-each', type=eval, choices=[True, False], default=False)
40 |
41 | parser.add_argument('--block', type=str, choices=['resblock', 'coupling'], default='resblock')
42 |
43 | parser.add_argument('--coeff', type=float, default=0.98)
44 | parser.add_argument('--vnorms', type=str, default='2222')
45 | parser.add_argument('--n-lipschitz-iters', type=int, default=None)
46 | parser.add_argument('--sn-tol', type=float, default=1e-3)
47 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False)
48 |
49 | parser.add_argument('--n-power-series', type=int, default=None)
50 | parser.add_argument('--factor-out', type=eval, choices=[True, False], default=False)
51 | parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric')
52 | parser.add_argument('--n-samples', type=int, default=1)
53 | parser.add_argument('--n-exact-terms', type=int, default=2)
54 | parser.add_argument('--var-reduc-lr', type=float, default=0)
55 | parser.add_argument('--neumann-grad', type=eval, choices=[True, False], default=True)
56 | parser.add_argument('--mem-eff', type=eval, choices=[True, False], default=True)
57 |
58 | parser.add_argument('--act', type=str, choices=ACT_FNS.keys(), default='swish')
59 | parser.add_argument('--idim', type=int, default=512)
60 | parser.add_argument('--nblocks', type=str, default='16-16-16')
61 | parser.add_argument('--squeeze-first', type=eval, default=False, choices=[True, False])
62 | parser.add_argument('--actnorm', type=eval, default=True, choices=[True, False])
63 | parser.add_argument('--fc-actnorm', type=eval, default=False, choices=[True, False])
64 | parser.add_argument('--batchnorm', type=eval, default=False, choices=[True, False])
65 | parser.add_argument('--dropout', type=float, default=0.)
66 | parser.add_argument('--fc', type=eval, default=False, choices=[True, False])
67 | parser.add_argument('--kernels', type=str, default='3-1-3')
68 | parser.add_argument('--quadratic', type=eval, choices=[True, False], default=False)
69 | parser.add_argument('--fc-end', type=eval, choices=[True, False], default=True)
70 | parser.add_argument('--fc-idim', type=int, default=128)
71 | parser.add_argument('--preact', type=eval, choices=[True, False], default=True)
72 | parser.add_argument('--padding', type=int, default=0)
73 | parser.add_argument('--first-resblock', type=eval, choices=[True, False], default=True)
74 |
75 | parser.add_argument('--task', type=str, choices=['density'], default='density')
76 | parser.add_argument('--rcrop-pad-mode', type=str, choices=['constant', 'reflect'], default='reflect')
77 | parser.add_argument('--padding-dist', type=str, choices=['uniform', 'gaussian'], default='uniform')
78 |
79 | parser.add_argument('--resume', type=str, required=True)
80 | parser.add_argument('--nworkers', type=int, default=4)
81 | args = parser.parse_args()
82 |
83 | W = args.ncol
84 | H = args.nrow
85 |
86 | args.batchsize = W * H
87 | args.val_batchsize = W * H
88 |
89 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
90 |
91 | if device.type == 'cuda':
92 | print('Found {} CUDA devices.'.format(torch.cuda.device_count()))
93 | for i in range(torch.cuda.device_count()):
94 | props = torch.cuda.get_device_properties(i)
95 | print('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3)))
96 | else:
97 | print('WARNING: Using device {}'.format(device))
98 |
99 |
100 | def geometric_logprob(ns, p):
101 | return torch.log(1 - p + 1e-10) * (ns - 1) + torch.log(p + 1e-10)
102 |
103 |
104 | def standard_normal_sample(size):
105 | return torch.randn(size)
106 |
107 |
108 | def standard_normal_logprob(z):
109 | logZ = -0.5 * math.log(2 * math.pi)
110 | return logZ - z.pow(2) / 2
111 |
112 |
113 | def count_parameters(model):
114 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
115 |
116 |
117 | def add_noise(x, nvals=255):
118 | """
119 | [0, 1] -> [0, nvals] -> add noise -> [0, 1]
120 | """
121 | noise = x.new().resize_as_(x).uniform_()
122 | x = x * nvals + noise
123 | x = x / (nvals + 1)
124 | return x
125 |
126 |
127 | def update_lr(optimizer, itr):
128 | iter_frac = min(float(itr + 1) / max(args.warmup_iters, 1), 1.0)
129 | lr = args.lr * iter_frac
130 | for param_group in optimizer.param_groups:
131 | param_group["lr"] = lr
132 |
133 |
134 | def add_padding(x):
135 | if args.padding > 0:
136 | u = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).uniform_()
137 | logpu = torch.zeros_like(u).sum([1, 2, 3])
138 | return torch.cat([u, x], dim=1), logpu
139 | else:
140 | return x, torch.zeros(x.shape[0]).to(x)
141 |
142 |
143 | def remove_padding(x):
144 | if args.padding > 0:
145 | return x[:, args.padding:, :, :]
146 | else:
147 | return x
148 |
149 |
150 | def reduce_bits(x):
151 | if args.nbits < 8:
152 | x = x * 255
153 | x = torch.floor(x / 2**(8 - args.nbits))
154 | x = x / 2**args.nbits
155 | return x
156 |
157 |
158 | def update_lipschitz(model):
159 | for m in model.modules():
160 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear):
161 | m.compute_weight(update=True)
162 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
163 | m.compute_weight(update=True)
164 |
165 |
166 | print('Loading dataset {}'.format(args.data), flush=True)
167 | # Dataset and hyperparameters
168 | if args.data == 'cifar10':
169 | im_dim = 3
170 | n_classes = 10
171 |
172 | if args.task in ['classification', 'hybrid']:
173 |
174 | if args.real:
175 |
176 | # Classification-specific preprocessing.
177 | transform_train = transforms.Compose([
178 | transforms.Resize(args.imagesize),
179 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode),
180 | transforms.RandomHorizontalFlip(),
181 | transforms.ToTensor(),
182 | add_noise,
183 | ])
184 |
185 | transform_test = transforms.Compose([
186 | transforms.Resize(args.imagesize),
187 | transforms.ToTensor(),
188 | add_noise,
189 | ])
190 |
191 | # Remove the logit transform.
192 | init_layer = layers.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
193 | else:
194 | if args.real:
195 | transform_train = transforms.Compose([
196 | transforms.Resize(args.imagesize),
197 | transforms.RandomHorizontalFlip(),
198 | transforms.ToTensor(),
199 | add_noise,
200 | ])
201 | transform_test = transforms.Compose([
202 | transforms.Resize(args.imagesize),
203 | transforms.ToTensor(),
204 | add_noise,
205 | ])
206 | init_layer = layers.LogitTransform(0.05)
207 | if args.real:
208 | train_loader = torch.utils.data.DataLoader(
209 | datasets.CIFAR10(args.dataroot, train=True, transform=transform_train),
210 | batch_size=args.batchsize,
211 | shuffle=True,
212 | num_workers=args.nworkers,
213 | )
214 | test_loader = torch.utils.data.DataLoader(
215 | datasets.CIFAR10(args.dataroot, train=False, transform=transform_test),
216 | batch_size=args.val_batchsize,
217 | shuffle=False,
218 | num_workers=args.nworkers,
219 | )
220 | elif args.data == 'mnist':
221 | im_dim = 1
222 | init_layer = layers.LogitTransform(1e-6)
223 | n_classes = 10
224 |
225 | if args.real:
226 | train_loader = torch.utils.data.DataLoader(
227 | datasets.MNIST(
228 | args.dataroot, train=True, transform=transforms.Compose([
229 | transforms.Resize(args.imagesize),
230 | transforms.ToTensor(),
231 | add_noise,
232 | ])
233 | ),
234 | batch_size=args.batchsize,
235 | shuffle=True,
236 | num_workers=args.nworkers,
237 | )
238 | test_loader = torch.utils.data.DataLoader(
239 | datasets.MNIST(
240 | args.dataroot, train=False, transform=transforms.Compose([
241 | transforms.Resize(args.imagesize),
242 | transforms.ToTensor(),
243 | add_noise,
244 | ])
245 | ),
246 | batch_size=args.val_batchsize,
247 | shuffle=False,
248 | num_workers=args.nworkers,
249 | )
250 | elif args.data == 'svhn':
251 | im_dim = 3
252 | init_layer = layers.LogitTransform(0.05)
253 | n_classes = 10
254 |
255 | if args.real:
256 | train_loader = torch.utils.data.DataLoader(
257 | vdsets.SVHN(
258 | args.dataroot, split='train', download=True, transform=transforms.Compose([
259 | transforms.Resize(args.imagesize),
260 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode),
261 | transforms.ToTensor(),
262 | add_noise,
263 | ])
264 | ),
265 | batch_size=args.batchsize,
266 | shuffle=True,
267 | num_workers=args.nworkers,
268 | )
269 | test_loader = torch.utils.data.DataLoader(
270 | vdsets.SVHN(
271 | args.dataroot, split='test', download=True, transform=transforms.Compose([
272 | transforms.Resize(args.imagesize),
273 | transforms.ToTensor(),
274 | add_noise,
275 | ])
276 | ),
277 | batch_size=args.val_batchsize,
278 | shuffle=False,
279 | num_workers=args.nworkers,
280 | )
281 | elif args.data == 'celebahq':
282 | im_dim = 3
283 | init_layer = layers.LogitTransform(0.05)
284 |
285 | if args.real:
286 | train_loader = torch.utils.data.DataLoader(
287 | datasets.CelebAHQ(
288 | train=True, transform=transforms.Compose([
289 | transforms.ToPILImage(),
290 | transforms.RandomHorizontalFlip(),
291 | transforms.ToTensor(),
292 | reduce_bits,
293 | lambda x: add_noise(x, nvals=2**args.nbits),
294 | ])
295 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
296 | )
297 | test_loader = torch.utils.data.DataLoader(
298 | datasets.CelebAHQ(
299 | train=False, transform=transforms.Compose([
300 | reduce_bits,
301 | lambda x: add_noise(x, nvals=2**args.nbits),
302 | ])
303 | ), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
304 | )
305 | elif args.data == 'celeba_5bit':
306 | im_dim = 3
307 | init_layer = layers.LogitTransform(0.05)
308 | if args.imagesize != 64:
309 | print('Changing image size to 64.')
310 | args.imagesize = 64
311 |
312 | if args.real:
313 | train_loader = torch.utils.data.DataLoader(
314 | datasets.CelebA5bit(
315 | train=True, transform=transforms.Compose([
316 | transforms.ToPILImage(),
317 | transforms.RandomHorizontalFlip(),
318 | transforms.ToTensor(),
319 | lambda x: add_noise(x, nvals=32),
320 | ])
321 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
322 | )
323 | test_loader = torch.utils.data.DataLoader(
324 | datasets.CelebA5bit(train=False, transform=transforms.Compose([
325 | lambda x: add_noise(x, nvals=32),
326 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
327 | )
328 | elif args.data == 'imagenet32':
329 | im_dim = 3
330 | init_layer = layers.LogitTransform(0.05)
331 | if args.imagesize != 32:
332 | print('Changing image size to 32.')
333 | args.imagesize = 32
334 |
335 | if args.real:
336 | train_loader = torch.utils.data.DataLoader(
337 | datasets.Imagenet32(train=True, transform=transforms.Compose([
338 | add_noise,
339 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
340 | )
341 | test_loader = torch.utils.data.DataLoader(
342 | datasets.Imagenet32(train=False, transform=transforms.Compose([
343 | add_noise,
344 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
345 | )
346 | elif args.data == 'imagenet64':
347 | im_dim = 3
348 | init_layer = layers.LogitTransform(0.05)
349 | if args.imagesize != 64:
350 | print('Changing image size to 64.')
351 | args.imagesize = 64
352 |
353 | if args.real:
354 | train_loader = torch.utils.data.DataLoader(
355 | datasets.Imagenet64(train=True, transform=transforms.Compose([
356 | add_noise,
357 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
358 | )
359 | test_loader = torch.utils.data.DataLoader(
360 | datasets.Imagenet64(train=False, transform=transforms.Compose([
361 | add_noise,
362 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
363 | )
364 |
365 | if args.task in ['classification', 'hybrid']:
366 | try:
367 | n_classes
368 | except NameError:
369 | raise ValueError('Cannot perform classification with {}'.format(args.data))
370 | else:
371 | n_classes = 1
372 |
373 | print('Dataset loaded.', flush=True)
374 | print('Creating model.', flush=True)
375 |
376 | input_size = (args.batchsize, im_dim + args.padding, args.imagesize, args.imagesize)
377 |
378 | if args.squeeze_first:
379 | input_size = (input_size[0], input_size[1] * 4, input_size[2] // 2, input_size[3] // 2)
380 | squeeze_layer = layers.SqueezeLayer(2)
381 |
382 | # Model
383 | model = ResidualFlow(
384 | input_size,
385 | n_blocks=list(map(int, args.nblocks.split('-'))),
386 | intermediate_dim=args.idim,
387 | factor_out=args.factor_out,
388 | quadratic=args.quadratic,
389 | init_layer=init_layer,
390 | actnorm=args.actnorm,
391 | fc_actnorm=args.fc_actnorm,
392 | batchnorm=args.batchnorm,
393 | dropout=args.dropout,
394 | fc=args.fc,
395 | coeff=args.coeff,
396 | vnorms=args.vnorms,
397 | n_lipschitz_iters=args.n_lipschitz_iters,
398 | sn_atol=args.sn_tol,
399 | sn_rtol=args.sn_tol,
400 | n_power_series=args.n_power_series,
401 | n_dist=args.n_dist,
402 | n_samples=args.n_samples,
403 | kernels=args.kernels,
404 | activation_fn=args.act,
405 | fc_end=args.fc_end,
406 | fc_idim=args.fc_idim,
407 | n_exact_terms=args.n_exact_terms,
408 | preact=args.preact,
409 | neumann_grad=args.neumann_grad,
410 | grad_in_forward=args.mem_eff,
411 | first_resblock=args.first_resblock,
412 | learn_p=args.learn_p,
413 | block_type=args.block,
414 | )
415 |
416 | model.to(device)
417 |
418 | print('Initializing model.', flush=True)
419 |
420 | with torch.no_grad():
421 | x = torch.rand(1, *input_size[1:]).to(device)
422 | model(x)
423 | print('Restoring from checkpoint.', flush=True)
424 | checkpt = torch.load(args.resume)
425 | # Load trained parameters; strict loading requires the architecture flags to match training.
426 | model.load_state_dict(checkpt['state_dict'], strict=True)
427 |
428 | ema = utils.ExponentialMovingAverage(model)
429 | ema.set(checkpt['ema'])
430 | ema.swap()
431 |
432 | print(model, flush=True)
433 |
434 | model.eval()
435 | print('Updating lipschitz.', flush=True)
436 | update_lipschitz(model)
437 |
438 |
439 | def visualize(model):
440 |     utils.makedirs('{}_imgs_t{}'.format(args.data, args.temp))
441 |     utils.makedirs('imgs')  # sample grids and real-image grids are saved here
442 | with torch.no_grad():
443 |
444 | for i in tqdm(range(args.nbatches)):
445 | # random samples
446 | rand_z = torch.randn(args.batchsize, (im_dim + args.padding) * args.imagesize * args.imagesize).to(device)
447 | rand_z = rand_z * args.temp
448 | fake_imgs = model(rand_z, inverse=True).view(-1, *input_size[1:])
449 | if args.squeeze_first: fake_imgs = squeeze_layer.inverse(fake_imgs)
450 | fake_imgs = remove_padding(fake_imgs)
451 | fake_imgs = fake_imgs.view(-1, im_dim, args.imagesize, args.imagesize)
452 | fake_imgs = fake_imgs.cpu()
453 |
454 | if args.save_each:
455 | for j in range(fake_imgs.shape[0]):
456 | save_image(
457 | fake_imgs[j], '{}_imgs_t{}/{}.png'.format(args.data, args.temp, args.batchsize * i + j), nrow=1,
458 | padding=0, range=(0, 1), pad_value=0
459 | )
460 | else:
461 | save_image(
462 | fake_imgs, 'imgs/{}_t{}_samples{}.png'.format(args.data, args.temp, i), nrow=W, padding=2,
463 | range=(0, 1), pad_value=1
464 | )
465 |
466 |
467 | # Optionally save a grid of real test images for comparison.
468 | if args.real:
469 |     real_imgs = next(iter(test_loader))[0]
470 |     save_image(
471 |         real_imgs.cpu().float(), 'imgs/{}_real.png'.format(args.data), nrow=W, padding=2, range=(0, 1), pad_value=1
472 |     )
473 |
474 | visualize(model)
475 |
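476 | # Example invocation (a sketch; the architecture flags must match those used at training
477 | # time, e.g. the README's CelebAHQ 256 flags for the corresponding checkpoint):
478 | #   python qualitative_samples.py --data celebahq --imagesize 256 --nbits 5 --act elu \
479 | #       --fc-end False --factor-out True --squeeze-first True --nblocks 16-16-16-16-16-16 \
480 | #       --resume checkpt.pth --temp 0.8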
--------------------------------------------------------------------------------
/resflows/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.datasets as vdsets
3 |
4 |
5 | class Dataset(object):
6 |
7 | def __init__(self, loc, transform=None, in_mem=True):
8 | self.in_mem = in_mem
9 | self.dataset = torch.load(loc)
10 | if in_mem: self.dataset = self.dataset.float().div(255)
11 | self.transform = transform
12 |
13 | def __len__(self):
14 | return self.dataset.size(0)
15 |
16 | @property
17 | def ndim(self):
18 | return self.dataset.size(1)
19 |
20 | def __getitem__(self, index):
21 | x = self.dataset[index]
22 | if not self.in_mem: x = x.float().div(255)
23 | x = self.transform(x) if self.transform is not None else x
24 | return x, 0
25 |
26 |
27 | class MNIST(object):
28 |
29 | def __init__(self, dataroot, train=True, transform=None):
30 | self.mnist = vdsets.MNIST(dataroot, train=train, download=True, transform=transform)
31 |
32 | def __len__(self):
33 | return len(self.mnist)
34 |
35 | @property
36 | def ndim(self):
37 | return 1
38 |
39 | def __getitem__(self, index):
40 | return self.mnist[index]
41 |
42 |
43 | class CIFAR10(object):
44 |
45 | def __init__(self, dataroot, train=True, transform=None):
46 | self.cifar10 = vdsets.CIFAR10(dataroot, train=train, download=True, transform=transform)
47 |
48 | def __len__(self):
49 | return len(self.cifar10)
50 |
51 | @property
52 | def ndim(self):
53 | return 3
54 |
55 | def __getitem__(self, index):
56 | return self.cifar10[index]
57 |
58 |
59 | class CelebA5bit(object):
60 |
61 | LOC = 'data/celebahq64_5bit/celeba_full_64x64_5bit.pth'
62 |
63 | def __init__(self, train=True, transform=None):
64 | self.dataset = torch.load(self.LOC).float().div(31)
65 | if not train:
66 | self.dataset = self.dataset[:5000]
67 | self.transform = transform
68 |
69 | def __len__(self):
70 | return self.dataset.size(0)
71 |
72 | @property
73 | def ndim(self):
74 | return self.dataset.size(1)
75 |
76 | def __getitem__(self, index):
77 | x = self.dataset[index]
78 | x = self.transform(x) if self.transform is not None else x
79 | return x, 0
80 |
81 |
82 | class CelebAHQ(Dataset):
83 | TRAIN_LOC = 'data/celebahq/celeba256_train.pth'
84 | TEST_LOC = 'data/celebahq/celeba256_validation.pth'
85 |
86 | def __init__(self, train=True, transform=None):
87 |         super(CelebAHQ, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform)
88 |
89 |
90 | class Imagenet32(Dataset):
91 | TRAIN_LOC = 'data/imagenet32/train_32x32.pth'
92 | TEST_LOC = 'data/imagenet32/valid_32x32.pth'
93 |
94 | def __init__(self, train=True, transform=None):
95 |         super(Imagenet32, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform)
96 |
97 |
98 | class Imagenet64(Dataset):
99 | TRAIN_LOC = 'data/imagenet64/train_64x64.pth'
100 | TEST_LOC = 'data/imagenet64/valid_64x64.pth'
101 |
102 | def __init__(self, train=True, transform=None):
103 |         super(Imagenet64, self).__init__(self.TRAIN_LOC if train else self.TEST_LOC, transform, in_mem=False)
104 |
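105 | # Minimal usage sketch (assumes the .pth files exist under data/, see README):
106 | #   loader = torch.utils.data.DataLoader(Imagenet32(train=True), batch_size=64, shuffle=True)
107 | #   x, y = next(iter(loader))  # x: floats in [0, 1], shape (64, 3, 32, 32); y is a dummy label 0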
--------------------------------------------------------------------------------
/resflows/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .act_norm import *
2 | from .container import *
3 | from .coupling import *
4 | from .elemwise import *
5 | from .iresblock import *
6 | from .normalization import *
7 | from .squeeze import *
8 | from .glow import *
9 |
--------------------------------------------------------------------------------
/resflows/layers/act_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Parameter
4 |
5 | __all__ = ['ActNorm1d', 'ActNorm2d']
6 |
7 |
8 | class ActNormNd(nn.Module):
9 |
10 | def __init__(self, num_features, eps=1e-12):
11 | super(ActNormNd, self).__init__()
12 | self.num_features = num_features
13 | self.eps = eps
14 | self.weight = Parameter(torch.Tensor(num_features))
15 | self.bias = Parameter(torch.Tensor(num_features))
16 | self.register_buffer('initialized', torch.tensor(0))
17 |
18 | @property
19 | def shape(self):
20 | raise NotImplementedError
21 |
22 | def forward(self, x, logpx=None):
23 | c = x.size(1)
24 |
25 | if not self.initialized:
26 | with torch.no_grad():
27 | # compute batch statistics
28 | x_t = x.transpose(0, 1).contiguous().view(c, -1)
29 | batch_mean = torch.mean(x_t, dim=1)
30 | batch_var = torch.var(x_t, dim=1)
31 |
32 |                 # clamp the variance away from zero for numerical stability
33 |                 batch_var = torch.max(batch_var, torch.tensor(0.2).to(batch_var))
34 |
35 | self.bias.data.copy_(-batch_mean)
36 | self.weight.data.copy_(-0.5 * torch.log(batch_var))
37 | self.initialized.fill_(1)
38 |
39 | bias = self.bias.view(*self.shape).expand_as(x)
40 | weight = self.weight.view(*self.shape).expand_as(x)
41 |
42 | y = (x + bias) * torch.exp(weight)
43 |
44 | if logpx is None:
45 | return y
46 | else:
47 | return y, logpx - self._logdetgrad(x)
48 |
49 | def inverse(self, y, logpy=None):
50 | assert self.initialized
51 | bias = self.bias.view(*self.shape).expand_as(y)
52 | weight = self.weight.view(*self.shape).expand_as(y)
53 |
54 | x = y * torch.exp(-weight) - bias
55 |
56 | if logpy is None:
57 | return x
58 | else:
59 | return x, logpy + self._logdetgrad(x)
60 |
61 | def _logdetgrad(self, x):
62 | return self.weight.view(*self.shape).expand(*x.size()).contiguous().view(x.size(0), -1).sum(1, keepdim=True)
63 |
64 | def __repr__(self):
65 | return ('{name}({num_features})'.format(name=self.__class__.__name__, **self.__dict__))
66 |
67 |
68 | class ActNorm1d(ActNormNd):
69 |
70 | @property
71 | def shape(self):
72 | return [1, -1]
73 |
74 |
75 | class ActNorm2d(ActNormNd):
76 |
77 | @property
78 | def shape(self):
79 | return [1, -1, 1, 1]
80 |
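81 | # Invertibility self-check sketch (not part of the library):
82 | #   m = ActNorm2d(3)
83 | #   y, logp = m(torch.randn(8, 3, 4, 4), torch.zeros(8, 1))  # first call initializes from batch stats
84 | #   x, _ = m.inverse(y, logp)  # recovers the input up to numerical precision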
--------------------------------------------------------------------------------
/resflows/layers/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .activations import *
2 | from .lipschitz import *
3 | from .mixed_lipschitz import *
4 |
--------------------------------------------------------------------------------
/resflows/layers/base/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Identity(nn.Module):
7 |
8 | def forward(self, x):
9 | return x
10 |
11 |
12 | class FullSort(nn.Module):
13 |
14 | def forward(self, x):
15 | return torch.sort(x, 1)[0]
16 |
17 |
18 | class MaxMin(nn.Module):
19 |
20 | def forward(self, x):
21 | b, d = x.shape
22 | max_vals = torch.max(x.view(b, d // 2, 2), 2)[0]
23 | min_vals = torch.min(x.view(b, d // 2, 2), 2)[0]
24 | return torch.cat([max_vals, min_vals], 1)
25 |
26 |
27 | class LipschitzCube(nn.Module):
28 |
29 | def forward(self, x):
30 | return (x >= 1).to(x) * (x - 2 / 3) + (x <= -1).to(x) * (x + 2 / 3) + ((x > -1) * (x < 1)).to(x) * x**3 / 3
31 |
32 |
33 | class SwishFn(torch.autograd.Function):
34 |
35 | @staticmethod
36 | def forward(ctx, x, beta):
37 | beta_sigm = torch.sigmoid(beta * x)
38 | output = x * beta_sigm
39 | ctx.save_for_backward(x, output, beta)
40 | return output / 1.1
41 |
42 | @staticmethod
43 | def backward(ctx, grad_output):
44 | x, output, beta = ctx.saved_tensors
45 | beta_sigm = output / x
46 | grad_x = grad_output * (beta * output + beta_sigm * (1 - beta * output))
47 | grad_beta = torch.sum(grad_output * (x * output - output * output)).expand_as(beta)
48 | return grad_x / 1.1, grad_beta / 1.1
49 |
50 |
51 | class Swish(nn.Module):
52 |
53 | def __init__(self):
54 | super(Swish, self).__init__()
55 | self.beta = nn.Parameter(torch.tensor([0.5]))
56 |
57 | def forward(self, x):
58 | return (x * torch.sigmoid_(x * F.softplus(self.beta))).div_(1.1)
59 |
60 |
61 | if __name__ == '__main__':
62 |
63 | m = Swish()
64 | xx = torch.linspace(-5, 5, 1000).requires_grad_(True)
65 | yy = m(xx)
66 |     dd, dbeta = torch.autograd.grad(yy.sum(), [xx, m.beta])  # dd is the exact pointwise derivative
67 |
68 | import matplotlib.pyplot as plt
69 |
70 | plt.plot(xx.detach().numpy(), yy.detach().numpy(), label='Func')
71 | plt.plot(xx.detach().numpy(), dd.detach().numpy(), label='Deriv')
72 | plt.plot(xx.detach().numpy(), torch.max(dd.detach().abs() - 1, torch.zeros_like(dd)).numpy(), label='|Deriv| > 1')
73 | plt.legend()
74 | plt.tight_layout()
75 | plt.show()
76 |
--------------------------------------------------------------------------------
/resflows/layers/base/lipschitz.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | import torch.nn.functional as F
6 |
7 | from .utils import _pair
8 | from .mixed_lipschitz import InducedNormLinear, InducedNormConv2d
9 |
10 | __all__ = ['SpectralNormLinear', 'SpectralNormConv2d', 'LopLinear', 'LopConv2d', 'get_linear', 'get_conv2d']
11 |
12 |
13 | class SpectralNormLinear(nn.Module):
14 |
15 | def __init__(
16 | self, in_features, out_features, bias=True, coeff=0.97, n_iterations=None, atol=None, rtol=None, **unused_kwargs
17 | ):
18 | del unused_kwargs
19 | super(SpectralNormLinear, self).__init__()
20 | self.in_features = in_features
21 | self.out_features = out_features
22 | self.coeff = coeff
23 | self.n_iterations = n_iterations
24 | self.atol = atol
25 | self.rtol = rtol
26 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
27 | if bias:
28 | self.bias = nn.Parameter(torch.Tensor(out_features))
29 | else:
30 | self.register_parameter('bias', None)
31 | self.reset_parameters()
32 |
33 | h, w = self.weight.shape
34 | self.register_buffer('scale', torch.tensor(0.))
35 | self.register_buffer('u', F.normalize(self.weight.new_empty(h).normal_(0, 1), dim=0))
36 | self.register_buffer('v', F.normalize(self.weight.new_empty(w).normal_(0, 1), dim=0))
37 | self.compute_weight(True, 200)
38 |
39 | def reset_parameters(self):
40 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
41 | if self.bias is not None:
42 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
43 | bound = 1 / math.sqrt(fan_in)
44 | init.uniform_(self.bias, -bound, bound)
45 |
46 | def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
47 | n_iterations = self.n_iterations if n_iterations is None else n_iterations
48 | atol = self.atol if atol is None else atol
49 |         rtol = self.rtol if rtol is None else rtol
50 |
51 |         if n_iterations is None and (atol is None or rtol is None):
52 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
53 |
54 | if n_iterations is None:
55 | n_iterations = 20000
56 |
57 | u = self.u
58 | v = self.v
59 | weight = self.weight
60 | if update:
61 | with torch.no_grad():
62 | itrs_used = 0.
63 | for _ in range(n_iterations):
64 | old_v = v.clone()
65 | old_u = u.clone()
66 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
67 | # are the first left and right singular vectors.
68 | # This power iteration produces approximations of `u` and `v`.
69 | v = F.normalize(torch.mv(weight.t(), u), dim=0, out=v)
70 | u = F.normalize(torch.mv(weight, v), dim=0, out=u)
71 | itrs_used = itrs_used + 1
72 | if atol is not None and rtol is not None:
73 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
74 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
75 | tol_u = atol + rtol * torch.max(u)
76 | tol_v = atol + rtol * torch.max(v)
77 | if err_u < tol_u and err_v < tol_v:
78 | break
79 | if itrs_used > 0:
80 | u = u.clone()
81 | v = v.clone()
82 |
83 | sigma = torch.dot(u, torch.mv(weight, v))
84 | with torch.no_grad():
85 | self.scale.copy_(sigma)
86 | # soft normalization: only when sigma larger than coeff
87 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
88 | weight = weight / factor
89 | return weight
90 |
91 | def forward(self, input):
92 | weight = self.compute_weight(update=self.training)
93 | return F.linear(input, weight, self.bias)
94 |
95 | def extra_repr(self):
96 | return 'in_features={}, out_features={}, bias={}, coeff={}, n_iters={}, atol={}, rtol={}'.format(
97 | self.in_features, self.out_features, self.bias is not None, self.coeff, self.n_iterations, self.atol,
98 | self.rtol
99 | )
100 |
101 |
102 | class SpectralNormConv2d(nn.Module):
103 |
104 | def __init__(
105 | self, in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, n_iterations=None,
106 | atol=None, rtol=None, **unused_kwargs
107 | ):
108 | del unused_kwargs
109 | super(SpectralNormConv2d, self).__init__()
110 | self.in_channels = in_channels
111 | self.out_channels = out_channels
112 | self.kernel_size = _pair(kernel_size)
113 | self.stride = _pair(stride)
114 | self.padding = _pair(padding)
115 | self.coeff = coeff
116 | self.n_iterations = n_iterations
117 | self.atol = atol
118 | self.rtol = rtol
119 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
120 | if bias:
121 | self.bias = nn.Parameter(torch.Tensor(out_channels))
122 | else:
123 | self.register_parameter('bias', None)
124 | self.reset_parameters()
125 | self.initialized = False
126 | self.register_buffer('spatial_dims', torch.tensor([1., 1.]))
127 | self.register_buffer('scale', torch.tensor(0.))
128 |
129 | def reset_parameters(self):
130 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
131 | if self.bias is not None:
132 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
133 | bound = 1 / math.sqrt(fan_in)
134 | init.uniform_(self.bias, -bound, bound)
135 |
136 | def _initialize_u_v(self):
137 | if self.kernel_size == (1, 1):
138 | self.register_buffer('u', F.normalize(self.weight.new_empty(self.out_channels).normal_(0, 1), dim=0))
139 | self.register_buffer('v', F.normalize(self.weight.new_empty(self.in_channels).normal_(0, 1), dim=0))
140 | else:
141 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
142 | with torch.no_grad():
143 | num_input_dim = c * h * w
144 | v = F.normalize(torch.randn(num_input_dim).to(self.weight), dim=0, eps=1e-12)
145 | # forward call to infer the shape
146 | u = F.conv2d(v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None)
147 | num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3]
148 | self.out_shape = u.shape
149 | # overwrite u with random init
150 | u = F.normalize(torch.randn(num_output_dim).to(self.weight), dim=0, eps=1e-12)
151 |
152 | self.register_buffer('u', u)
153 | self.register_buffer('v', v)
154 |
155 | def compute_weight(self, update=True, n_iterations=None):
156 | if not self.initialized:
157 | self._initialize_u_v()
158 | self.initialized = True
159 |
160 | if self.kernel_size == (1, 1):
161 | return self._compute_weight_1x1(update, n_iterations)
162 | else:
163 | return self._compute_weight_kxk(update, n_iterations)
164 |
165 | def _compute_weight_1x1(self, update=True, n_iterations=None, atol=None, rtol=None):
166 | n_iterations = self.n_iterations if n_iterations is None else n_iterations
167 | atol = self.atol if atol is None else atol
168 |         rtol = self.rtol if rtol is None else rtol
169 |
170 |         if n_iterations is None and (atol is None or rtol is None):
171 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
172 |
173 | if n_iterations is None:
174 | n_iterations = 20000
175 |
176 | u = self.u
177 | v = self.v
178 | weight = self.weight.view(self.out_channels, self.in_channels)
179 | if update:
180 | with torch.no_grad():
181 | itrs_used = 0
182 | for _ in range(n_iterations):
183 | old_v = v.clone()
184 | old_u = u.clone()
185 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
186 | # are the first left and right singular vectors.
187 | # This power iteration produces approximations of `u` and `v`.
188 | v = F.normalize(torch.mv(weight.t(), u), dim=0, out=v)
189 | u = F.normalize(torch.mv(weight, v), dim=0, out=u)
190 | itrs_used = itrs_used + 1
191 | if atol is not None and rtol is not None:
192 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
193 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
194 | tol_u = atol + rtol * torch.max(u)
195 | tol_v = atol + rtol * torch.max(v)
196 | if err_u < tol_u and err_v < tol_v:
197 | break
198 | if itrs_used > 0:
199 | u = u.clone()
200 | v = v.clone()
201 |
202 | sigma = torch.dot(u, torch.mv(weight, v))
203 | with torch.no_grad():
204 | self.scale.copy_(sigma)
205 | # soft normalization: only when sigma larger than coeff
206 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
207 | weight = weight / factor
208 | return weight.view(self.out_channels, self.in_channels, 1, 1)
209 |
210 | def _compute_weight_kxk(self, update=True, n_iterations=None, atol=None, rtol=None):
211 | n_iterations = self.n_iterations if n_iterations is None else n_iterations
212 | atol = self.atol if atol is None else atol
213 |         rtol = self.rtol if rtol is None else rtol
214 |
215 |         if n_iterations is None and (atol is None or rtol is None):
216 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
217 |
218 | if n_iterations is None:
219 | n_iterations = 20000
220 |
221 | u = self.u
222 | v = self.v
223 | weight = self.weight
224 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
225 | if update:
226 | with torch.no_grad():
227 | itrs_used = 0
228 | for _ in range(n_iterations):
229 | old_u = u.clone()
230 | old_v = v.clone()
231 | v_s = F.conv_transpose2d(
232 | u.view(self.out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0
233 | )
234 | v = F.normalize(v_s.view(-1), dim=0, out=v)
235 | u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
236 | u = F.normalize(u_s.view(-1), dim=0, out=u)
237 | itrs_used = itrs_used + 1
238 | if atol is not None and rtol is not None:
239 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
240 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
241 | tol_u = atol + rtol * torch.max(u)
242 | tol_v = atol + rtol * torch.max(v)
243 | if err_u < tol_u and err_v < tol_v:
244 | break
245 | if itrs_used > 0:
246 | u = u.clone()
247 | v = v.clone()
248 |
249 | weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
250 | weight_v = weight_v.view(-1)
251 | sigma = torch.dot(u.view(-1), weight_v)
252 | with torch.no_grad():
253 | self.scale.copy_(sigma)
254 | # soft normalization: only when sigma larger than coeff
255 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
256 | weight = weight / factor
257 | return weight
258 |
259 | def forward(self, input):
260 | if not self.initialized: self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims))
261 | weight = self.compute_weight(update=self.training)
262 | return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)
263 |
264 | def extra_repr(self):
265 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}')
266 | if self.padding != (0,) * len(self.padding):
267 | s += ', padding={padding}'
268 | if self.bias is None:
269 | s += ', bias=False'
270 | s += ', coeff={}, n_iters={}, atol={}, rtol={}'.format(self.coeff, self.n_iterations, self.atol, self.rtol)
271 | return s.format(**self.__dict__)
272 |
273 |
274 | class LopLinear(nn.Linear):
275 | """Lipschitz constant defined using operator norms."""
276 |
277 | def __init__(
278 | self,
279 | in_features,
280 | out_features,
281 | bias=True,
282 | coeff=0.97,
283 | domain=float('inf'),
284 | codomain=float('inf'),
285 | local_constraint=True,
286 | **unused_kwargs,
287 | ):
288 | del unused_kwargs
289 | super(LopLinear, self).__init__(in_features, out_features, bias)
290 | self.coeff = coeff
291 | self.domain = domain
292 | self.codomain = codomain
293 | self.local_constraint = local_constraint
294 | max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
295 | self.max_across_dim = 1 if max_across_input_dims else 0
296 | self.register_buffer('scale', torch.tensor(0.))
297 |
298 | def compute_weight(self):
299 | scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
300 | if not self.local_constraint: scale = scale.max()
301 | with torch.no_grad():
302 | self.scale.copy_(scale.max())
303 |
304 | # soft normalization
305 | factor = torch.max(torch.ones(1).to(self.weight), scale / self.coeff)
306 |
307 | return self.weight / factor
308 |
309 | def forward(self, input):
310 | weight = self.compute_weight()
311 | return F.linear(input, weight, self.bias)
312 |
313 | def extra_repr(self):
314 | s = super(LopLinear, self).extra_repr()
315 | return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
316 | self.coeff, self.domain, self.codomain, self.local_constraint
317 | )
318 |
319 |
320 | class LopConv2d(nn.Conv2d):
321 | """Lipschitz constant defined using operator norms."""
322 |
323 | def __init__(
324 | self,
325 | in_channels,
326 | out_channels,
327 | kernel_size,
328 | stride,
329 | padding,
330 | bias=True,
331 | coeff=0.97,
332 | domain=float('inf'),
333 | codomain=float('inf'),
334 | local_constraint=True,
335 | **unused_kwargs,
336 | ):
337 | del unused_kwargs
338 | super(LopConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias)
339 | self.coeff = coeff
340 | self.domain = domain
341 | self.codomain = codomain
342 | self.local_constraint = local_constraint
343 | max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
344 | self.max_across_dim = 1 if max_across_input_dims else 0
345 | self.register_buffer('scale', torch.tensor(0.))
346 |
347 | def compute_weight(self):
348 | scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
349 | if not self.local_constraint: scale = scale.max()
350 | with torch.no_grad():
351 | self.scale.copy_(scale.max())
352 |
353 | # soft normalization
354 | factor = torch.max(torch.ones(1).to(self.weight.device), scale / self.coeff)
355 |
356 | return self.weight / factor
357 |
358 | def forward(self, input):
359 | weight = self.compute_weight()
360 | return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)
361 |
362 | def extra_repr(self):
363 | s = super(LopConv2d, self).extra_repr()
364 | return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
365 | self.coeff, self.domain, self.codomain, self.local_constraint
366 | )
367 |
368 |
369 | class LipNormLinear(nn.Linear):
370 | """Lipschitz constant defined using operator norms."""
371 |
372 | def __init__(
373 | self,
374 | in_features,
375 | out_features,
376 | bias=True,
377 | coeff=0.97,
378 | domain=float('inf'),
379 | codomain=float('inf'),
380 | local_constraint=True,
381 | **unused_kwargs,
382 | ):
383 | del unused_kwargs
384 | super(LipNormLinear, self).__init__(in_features, out_features, bias)
385 | self.coeff = coeff
386 | self.domain = domain
387 | self.codomain = codomain
388 | self.local_constraint = local_constraint
389 | max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
390 | self.max_across_dim = 1 if max_across_input_dims else 0
391 |
392 | # Initialize scale parameter.
393 | with torch.no_grad():
394 | w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
395 | if not self.local_constraint: w_scale = w_scale.max()
396 | self.scale = nn.Parameter(_logit(w_scale / self.coeff))
397 |
398 | def compute_weight(self):
399 | w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
400 | if not self.local_constraint: w_scale = w_scale.max()
401 | return self.weight / w_scale * torch.sigmoid(self.scale) * self.coeff
402 |
403 | def forward(self, input):
404 | weight = self.compute_weight()
405 | return F.linear(input, weight, self.bias)
406 |
407 | def extra_repr(self):
408 | s = super(LipNormLinear, self).extra_repr()
409 | return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
410 | self.coeff, self.domain, self.codomain, self.local_constraint
411 | )
412 |
413 |
414 | class LipNormConv2d(nn.Conv2d):
415 | """Lipschitz constant defined using operator norms."""
416 |
417 | def __init__(
418 | self,
419 | in_channels,
420 | out_channels,
421 | kernel_size,
422 | stride,
423 | padding,
424 | bias=True,
425 | coeff=0.97,
426 | domain=float('inf'),
427 | codomain=float('inf'),
428 | local_constraint=True,
429 | **unused_kwargs,
430 | ):
431 | del unused_kwargs
432 | super(LipNormConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias)
433 | self.coeff = coeff
434 | self.domain = domain
435 | self.codomain = codomain
436 | self.local_constraint = local_constraint
437 | max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
438 | self.max_across_dim = 1 if max_across_input_dims else 0
439 |
440 | # Initialize scale parameter.
441 | with torch.no_grad():
442 | w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
443 | if not self.local_constraint: w_scale = w_scale.max()
444 | self.scale = nn.Parameter(_logit(w_scale / self.coeff))
445 |
446 | def compute_weight(self):
447 | w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
448 | if not self.local_constraint: w_scale = w_scale.max()
449 |         return self.weight / w_scale * torch.sigmoid(self.scale) * self.coeff  # match LipNormLinear
450 |
451 | def forward(self, input):
452 | weight = self.compute_weight()
453 | return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)
454 |
455 | def extra_repr(self):
456 | s = super(LipNormConv2d, self).extra_repr()
457 | return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
458 | self.coeff, self.domain, self.codomain, self.local_constraint
459 | )
460 |
461 |
462 | def _logit(p):
463 | p = torch.max(torch.ones(1) * 0.1, torch.min(torch.ones(1) * 0.9, p))
464 |     return torch.log(p + 1e-10) - torch.log(1 - p + 1e-10)
465 |
466 |
467 | def _norm_except_dim(w, norm_type, dim):
468 | if norm_type == 1 or norm_type == 2:
469 | return torch.norm_except_dim(w, norm_type, dim)
470 | elif norm_type == float('inf'):
471 | return _max_except_dim(w, dim)
472 |
473 |
474 | def _max_except_dim(input, dim):
475 | maxed = input
476 | for axis in range(input.ndimension() - 1, dim, -1):
477 | maxed, _ = maxed.max(axis, keepdim=True)
478 | for axis in range(dim - 1, -1, -1):
479 | maxed, _ = maxed.max(axis, keepdim=True)
480 | return maxed
481 |
482 |
483 | def operator_norm_settings(domain, codomain):
484 | if domain == 1 and codomain == 1:
485 | # maximum l1-norm of column
486 | max_across_input_dims = True
487 | norm_type = 1
488 | elif domain == 1 and codomain == 2:
489 | # maximum l2-norm of column
490 | max_across_input_dims = True
491 | norm_type = 2
492 | elif domain == 1 and codomain == float("inf"):
493 | # maximum l-inf norm of column
494 | max_across_input_dims = True
495 | norm_type = float("inf")
496 | elif domain == 2 and codomain == float("inf"):
497 | # maximum l2-norm of row
498 | max_across_input_dims = False
499 | norm_type = 2
500 | elif domain == float("inf") and codomain == float("inf"):
501 | # maximum l1-norm of row
502 | max_across_input_dims = False
503 | norm_type = 1
504 | else:
505 | raise ValueError('Unknown combination of domain "{}" and codomain "{}"'.format(domain, codomain))
506 |
507 | return max_across_input_dims, norm_type
508 |
509 |
510 | def get_linear(in_features, out_features, bias=True, coeff=0.97, domain=None, codomain=None, **kwargs):
511 | _linear = InducedNormLinear
512 | if domain == 1:
513 | if codomain in [1, 2, float('inf')]:
514 | _linear = LopLinear
515 | elif codomain == float('inf'):
516 | if domain in [2, float('inf')]:
517 | _linear = LopLinear
518 | return _linear(in_features, out_features, bias, coeff, domain, codomain, **kwargs)
519 |
520 |
521 | def get_conv2d(
522 | in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, domain=None, codomain=None, **kwargs
523 | ):
524 | _conv2d = InducedNormConv2d
525 | if domain == 1:
526 | if codomain in [1, 2, float('inf')]:
527 | _conv2d = LopConv2d
528 | elif codomain == float('inf'):
529 | if domain in [2, float('inf')]:
530 | _conv2d = LopConv2d
531 | return _conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, coeff, domain, codomain, **kwargs)
532 |
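533 | # Usage sketch (hypothetical sizes): a 3x3 conv with its Lipschitz constant softly
534 | # bounded by ~0.97; domain=codomain=2 selects InducedNormConv2d (spectral norm).
535 | #   conv = get_conv2d(64, 64, 3, 1, 1, coeff=0.97, domain=2, codomain=2, atol=1e-3, rtol=1e-3)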
--------------------------------------------------------------------------------
/resflows/layers/base/mixed_lipschitz.py:
--------------------------------------------------------------------------------
1 | import collections.abc as container_abcs
2 | from itertools import repeat
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.init as init
7 | import torch.nn.functional as F
8 |
9 | __all__ = ['InducedNormLinear', 'InducedNormConv2d']
10 |
11 |
12 | class InducedNormLinear(nn.Module):
13 |
14 | def __init__(
15 | self, in_features, out_features, bias=True, coeff=0.97, domain=2, codomain=2, n_iterations=None, atol=None,
16 | rtol=None, zero_init=False, **unused_kwargs
17 | ):
18 | del unused_kwargs
19 | super(InducedNormLinear, self).__init__()
20 | self.in_features = in_features
21 | self.out_features = out_features
22 | self.coeff = coeff
23 | self.n_iterations = n_iterations
24 | self.atol = atol
25 | self.rtol = rtol
26 | self.domain = domain
27 | self.codomain = codomain
28 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
29 | if bias:
30 | self.bias = nn.Parameter(torch.Tensor(out_features))
31 | else:
32 | self.register_parameter('bias', None)
33 | self.reset_parameters(zero_init)
34 |
35 | with torch.no_grad():
36 | domain, codomain = self.compute_domain_codomain()
37 |
38 | h, w = self.weight.shape
39 | self.register_buffer('scale', torch.tensor(0.))
40 | self.register_buffer('u', normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain))
41 | self.register_buffer('v', normalize_v(self.weight.new_empty(w).normal_(0, 1), domain))
42 |
43 | # Try different random seeds to find the best u and v.
44 | with torch.no_grad():
45 | self.compute_weight(True, n_iterations=200, atol=None, rtol=None)
46 | best_scale = self.scale.clone()
47 | best_u, best_v = self.u.clone(), self.v.clone()
48 | if not (domain == 2 and codomain == 2):
49 | for _ in range(10):
50 | self.register_buffer('u', normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain))
51 | self.register_buffer('v', normalize_v(self.weight.new_empty(w).normal_(0, 1), domain))
52 | self.compute_weight(True, n_iterations=200)
53 | if self.scale > best_scale:
54 |                         best_scale, best_u, best_v = self.scale.clone(), self.u.clone(), self.v.clone()
55 | self.u.copy_(best_u)
56 | self.v.copy_(best_v)
57 |
58 | def reset_parameters(self, zero_init=False):
59 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
60 | if zero_init:
61 | # normalize cannot handle zero weight in some cases.
62 | self.weight.data.div_(1000)
63 | if self.bias is not None:
64 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
65 | bound = 1 / math.sqrt(fan_in)
66 | init.uniform_(self.bias, -bound, bound)
67 |
68 | def compute_domain_codomain(self):
69 | if torch.is_tensor(self.domain):
70 | domain = asym_squash(self.domain)
71 | codomain = asym_squash(self.codomain)
72 | else:
73 | domain, codomain = self.domain, self.codomain
74 | return domain, codomain
75 |
76 | def compute_one_iter(self):
77 | domain, codomain = self.compute_domain_codomain()
78 | u = self.u.detach()
79 | v = self.v.detach()
80 | weight = self.weight.detach()
81 | u = normalize_u(torch.mv(weight, v), codomain)
82 | v = normalize_v(torch.mv(weight.t(), u), domain)
83 | return torch.dot(u, torch.mv(weight, v))
84 |
85 | def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
86 | u = self.u
87 | v = self.v
88 | weight = self.weight
89 |
90 | if update:
91 |
92 | n_iterations = self.n_iterations if n_iterations is None else n_iterations
93 | atol = self.atol if atol is None else atol
94 |             rtol = self.rtol if rtol is None else rtol
95 |
96 | if n_iterations is None and (atol is None or rtol is None):
97 |                 raise ValueError('Need one of n_iterations or (atol, rtol).')
98 |
99 | max_itrs = 200
100 | if n_iterations is not None:
101 | max_itrs = n_iterations
102 |
103 | with torch.no_grad():
104 | domain, codomain = self.compute_domain_codomain()
105 | for _ in range(max_itrs):
106 | # Algorithm from http://www.qetlab.com/InducedMatrixNorm.
107 | if n_iterations is None and atol is not None and rtol is not None:
108 | old_v = v.clone()
109 | old_u = u.clone()
110 |
111 | u = normalize_u(torch.mv(weight, v), codomain, out=u)
112 | v = normalize_v(torch.mv(weight.t(), u), domain, out=v)
113 |
114 | if n_iterations is None and atol is not None and rtol is not None:
115 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
116 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
117 | tol_u = atol + rtol * torch.max(u)
118 | tol_v = atol + rtol * torch.max(v)
119 | if err_u < tol_u and err_v < tol_v:
120 | break
121 |                 self.v.copy_(v)
122 |                 self.u.copy_(u)
123 |                 u = u.clone()
124 |                 v = v.clone()
125 |
126 | sigma = torch.dot(u, torch.mv(weight, v))
127 | with torch.no_grad():
128 | self.scale.copy_(sigma)
129 | # soft normalization: only when sigma larger than coeff
130 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
131 | weight = weight / factor
132 | return weight
133 |
134 | def forward(self, input):
135 | weight = self.compute_weight(update=False)
136 | return F.linear(input, weight, self.bias)
137 |
138 | def extra_repr(self):
139 | domain, codomain = self.compute_domain_codomain()
140 | return (
141 | 'in_features={}, out_features={}, bias={}'
142 | ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format(
143 | self.in_features, self.out_features, self.bias is not None, self.coeff, domain, codomain,
144 | self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain)
145 | )
146 | )
147 |
148 |
149 | class InducedNormConv2d(nn.Module):
150 |
151 | def __init__(
152 | self, in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, domain=2, codomain=2,
153 | n_iterations=None, atol=None, rtol=None, **unused_kwargs
154 | ):
155 | del unused_kwargs
156 | super(InducedNormConv2d, self).__init__()
157 | self.in_channels = in_channels
158 | self.out_channels = out_channels
159 | self.kernel_size = _pair(kernel_size)
160 | self.stride = _pair(stride)
161 | self.padding = _pair(padding)
162 | self.coeff = coeff
163 | self.n_iterations = n_iterations
164 | self.domain = domain
165 | self.codomain = codomain
166 | self.atol = atol
167 | self.rtol = rtol
168 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
169 | if bias:
170 | self.bias = nn.Parameter(torch.Tensor(out_channels))
171 | else:
172 | self.register_parameter('bias', None)
173 | self.reset_parameters()
174 | self.register_buffer('initialized', torch.tensor(0))
175 | self.register_buffer('spatial_dims', torch.tensor([1., 1.]))
176 | self.register_buffer('scale', torch.tensor(0.))
177 | self.register_buffer('u', self.weight.new_empty(self.out_channels))
178 | self.register_buffer('v', self.weight.new_empty(self.in_channels))
179 |
180 | def compute_domain_codomain(self):
181 | if torch.is_tensor(self.domain):
182 | domain = asym_squash(self.domain)
183 | codomain = asym_squash(self.codomain)
184 | else:
185 | domain, codomain = self.domain, self.codomain
186 | return domain, codomain
187 |
188 | def reset_parameters(self):
189 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
190 | if self.bias is not None:
191 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
192 | bound = 1 / math.sqrt(fan_in)
193 | init.uniform_(self.bias, -bound, bound)
194 |
195 | def _initialize_u_v(self):
196 | with torch.no_grad():
197 | domain, codomain = self.compute_domain_codomain()
198 | if self.kernel_size == (1, 1):
199 | self.u.resize_(self.out_channels).normal_(0, 1)
200 | self.u.copy_(normalize_u(self.u, codomain))
201 | self.v.resize_(self.in_channels).normal_(0, 1)
202 | self.v.copy_(normalize_v(self.v, domain))
203 | else:
204 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
205 | with torch.no_grad():
206 | num_input_dim = c * h * w
207 | self.v.resize_(num_input_dim).normal_(0, 1)
208 | self.v.copy_(normalize_v(self.v, domain))
209 | # forward call to infer the shape
210 | u = F.conv2d(
211 | self.v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None
212 | )
213 | num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3]
214 | # overwrite u with random init
215 | self.u.resize_(num_output_dim).normal_(0, 1)
216 | self.u.copy_(normalize_u(self.u, codomain))
217 |
218 | self.initialized.fill_(1)
219 |
220 | # Try different random seeds to find the best u and v.
221 | self.compute_weight(True)
222 | best_scale = self.scale.clone()
223 | best_u, best_v = self.u.clone(), self.v.clone()
224 | if not (domain == 2 and codomain == 2):
225 | for _ in range(10):
226 | if self.kernel_size == (1, 1):
227 | self.u.copy_(normalize_u(self.weight.new_empty(self.out_channels).normal_(0, 1), codomain))
228 | self.v.copy_(normalize_v(self.weight.new_empty(self.in_channels).normal_(0, 1), domain))
229 | else:
230 | self.u.copy_(normalize_u(torch.randn(num_output_dim).to(self.weight), codomain))
231 | self.v.copy_(normalize_v(torch.randn(num_input_dim).to(self.weight), domain))
232 | self.compute_weight(True, n_iterations=200)
233 | if self.scale > best_scale:
234 |                         best_scale, best_u, best_v = self.scale.clone(), self.u.clone(), self.v.clone()
235 | self.u.copy_(best_u)
236 | self.v.copy_(best_v)
237 |
238 | def compute_one_iter(self):
239 | if not self.initialized:
240 | raise ValueError('Layer needs to be initialized first.')
241 | domain, codomain = self.compute_domain_codomain()
242 | if self.kernel_size == (1, 1):
243 | u = self.u.detach()
244 | v = self.v.detach()
245 | weight = self.weight.detach().view(self.out_channels, self.in_channels)
246 | u = normalize_u(torch.mv(weight, v), codomain)
247 | v = normalize_v(torch.mv(weight.t(), u), domain)
248 | return torch.dot(u, torch.mv(weight, v))
249 | else:
250 | u = self.u.detach()
251 | v = self.v.detach()
252 | weight = self.weight.detach()
253 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
254 | u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
255 | out_shape = u_s.shape
256 | u = normalize_u(u_s.view(-1), codomain)
257 | v_s = F.conv_transpose2d(
258 | u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0
259 | )
260 | v = normalize_v(v_s.view(-1), domain)
261 | weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
262 | return torch.dot(u.view(-1), weight_v.view(-1))
263 |
264 | def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
265 | if not self.initialized:
266 | self._initialize_u_v()
267 |
268 | if self.kernel_size == (1, 1):
269 | return self._compute_weight_1x1(update, n_iterations, atol, rtol)
270 | else:
271 | return self._compute_weight_kxk(update, n_iterations, atol, rtol)
272 |
273 | def _compute_weight_1x1(self, update=True, n_iterations=None, atol=None, rtol=None):
274 | n_iterations = self.n_iterations if n_iterations is None else n_iterations
275 | atol = self.atol if atol is None else atol
276 |         rtol = self.rtol if rtol is None else rtol
277 |
278 | if n_iterations is None and (atol is None or rtol is None):
279 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
280 |
281 | max_itrs = 200
282 | if n_iterations is not None:
283 | max_itrs = n_iterations
284 |
285 | u = self.u
286 | v = self.v
287 | weight = self.weight.view(self.out_channels, self.in_channels)
288 | if update:
289 | with torch.no_grad():
290 | domain, codomain = self.compute_domain_codomain()
291 | itrs_used = 0
292 | for _ in range(max_itrs):
293 | old_v = v.clone()
294 | old_u = u.clone()
295 |
296 | u = normalize_u(torch.mv(weight, v), codomain, out=u)
297 | v = normalize_v(torch.mv(weight.t(), u), domain, out=v)
298 |
299 | itrs_used = itrs_used + 1
300 |
301 | if n_iterations is None and atol is not None and rtol is not None:
302 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
303 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
304 | tol_u = atol + rtol * torch.max(u)
305 | tol_v = atol + rtol * torch.max(v)
306 | if err_u < tol_u and err_v < tol_v:
307 | break
308 | if itrs_used > 0:
309 | if domain != 1 and domain != 2:
310 | self.v.copy_(v)
311 | if codomain != 2 and codomain != float('inf'):
312 | self.u.copy_(u)
313 | u = u.clone()
314 | v = v.clone()
315 |
316 | sigma = torch.dot(u, torch.mv(weight, v))
317 | with torch.no_grad():
318 | self.scale.copy_(sigma)
319 | # soft normalization: only when sigma larger than coeff
320 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
321 | weight = weight / factor
322 | return weight.view(self.out_channels, self.in_channels, 1, 1)
323 |
324 | def _compute_weight_kxk(self, update=True, n_iterations=None, atol=None, rtol=None):
325 | n_iterations = self.n_iterations if n_iterations is None else n_iterations
326 | atol = self.atol if atol is None else atol
327 |         rtol = self.rtol if rtol is None else rtol
328 |
329 | if n_iterations is None and (atol is None or rtol is None):
330 |             raise ValueError('Need one of n_iterations or (atol, rtol).')
331 |
332 | max_itrs = 200
333 | if n_iterations is not None:
334 | max_itrs = n_iterations
335 |
336 | u = self.u
337 | v = self.v
338 | weight = self.weight
339 | c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
340 | if update:
341 | with torch.no_grad():
342 | domain, codomain = self.compute_domain_codomain()
343 | itrs_used = 0
344 | for _ in range(max_itrs):
345 | old_u = u.clone()
346 | old_v = v.clone()
347 |
348 | u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
349 | out_shape = u_s.shape
350 | u = normalize_u(u_s.view(-1), codomain, out=u)
351 |
352 | v_s = F.conv_transpose2d(
353 | u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0
354 | )
355 | v = normalize_v(v_s.view(-1), domain, out=v)
356 |
357 | itrs_used = itrs_used + 1
358 | if n_iterations is None and atol is not None and rtol is not None:
359 | err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
360 | err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
361 | tol_u = atol + rtol * torch.max(u)
362 | tol_v = atol + rtol * torch.max(v)
363 | if err_u < tol_u and err_v < tol_v:
364 | break
365 | if itrs_used > 0:
366 | if domain != 2:
367 | self.v.copy_(v)
368 | if codomain != 2:
369 | self.u.copy_(u)
370 | v = v.clone()
371 | u = u.clone()
372 |
373 | weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
374 | weight_v = weight_v.view(-1)
375 | sigma = torch.dot(u.view(-1), weight_v)
376 | with torch.no_grad():
377 | self.scale.copy_(sigma)
378 | # soft normalization: only when sigma larger than coeff
379 | factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
380 | weight = weight / factor
381 | return weight
382 |
383 | def forward(self, input):
384 | if not self.initialized: self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims))
385 | weight = self.compute_weight(update=False)
386 | return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)
387 |
388 | def extra_repr(self):
389 | domain, codomain = self.compute_domain_codomain()
390 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}')
391 | if self.padding != (0,) * len(self.padding):
392 | s += ', padding={padding}'
393 | if self.bias is None:
394 | s += ', bias=False'
395 | s += ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format(
396 | self.coeff, domain, codomain, self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain)
397 | )
398 | return s.format(**self.__dict__)
399 |
400 |
401 | def projmax_(v):
402 | """Inplace argmax on absolute value."""
403 | ind = torch.argmax(torch.abs(v))
404 | v.zero_()
405 | v[ind] = 1
406 | return v
407 |
408 |
409 | def normalize_v(v, domain, out=None):
410 | if not torch.is_tensor(domain) and domain == 2:
411 | v = F.normalize(v, p=2, dim=0, out=out)
412 | elif domain == 1:
413 | v = projmax_(v)
414 | else:
415 | vabs = torch.abs(v)
416 | vph = v / vabs
417 | vph[torch.isnan(vph)] = 1
418 | vabs = vabs / torch.max(vabs)
419 | vabs = vabs**(1 / (domain - 1))
420 | v = vph * vabs / vector_norm(vabs, domain)
421 | return v
422 |
423 |
424 | def normalize_u(u, codomain, out=None):
425 | if not torch.is_tensor(codomain) and codomain == 2:
426 | u = F.normalize(u, p=2, dim=0, out=out)
427 | elif codomain == float('inf'):
428 | u = projmax_(u)
429 | else:
430 | uabs = torch.abs(u)
431 | uph = u / uabs
432 | uph[torch.isnan(uph)] = 1
433 | uabs = uabs / torch.max(uabs)
434 | uabs = uabs**(codomain - 1)
435 | if codomain == 1:
436 | u = uph * uabs / vector_norm(uabs, float('inf'))
437 | else:
438 | u = uph * uabs / vector_norm(uabs, codomain / (codomain - 1))
439 | return u
440 |
441 |
442 | def vector_norm(x, p):
443 | x = x.view(-1)
444 |     return torch.sum(x**p)**(1 / p)  # assumes x >= 0 elementwise; callers pass normalized absolute values
445 |
446 |
447 | def leaky_elu(x, a=0.3):
448 | return a * x + (1 - a) * F.elu(x)
449 |
450 |
451 | def asym_squash(x):
452 |     return torch.tanh(-leaky_elu(-x + 0.5493061829986572)) * 2 + 3  # smoothly squashes the real line into (1, 5)
453 |
454 |
455 | # def asym_squash(x):
456 | # return torch.tanh(x) / 2. + 2.
457 |
458 |
459 | def _ntuple(n):
460 |
461 | def parse(x):
462 | if isinstance(x, container_abcs.Iterable):
463 | return x
464 | return tuple(repeat(x, n))
465 |
466 | return parse
467 |
468 |
469 | _single = _ntuple(1)
470 | _pair = _ntuple(2)
471 | _triple = _ntuple(3)
472 | _quadruple = _ntuple(4)
473 |
474 | if __name__ == '__main__':
475 |
476 | p = nn.Parameter(torch.tensor(2.1))
477 |
478 | m = InducedNormConv2d(10, 2, 3, 1, 1, atol=1e-3, rtol=1e-3, domain=p, codomain=p)
479 | W = m.compute_weight()
480 |
481 | m.compute_one_iter().backward()
482 | print(p.grad)
483 |
484 | # m.weight.data.copy_(W)
485 | # W = m.compute_weight().cpu().detach().numpy()
486 | # import numpy as np
487 | # print(
488 | # '{} {} {}'.format(
489 | # np.linalg.norm(W, ord=2, axis=(0, 1)),
490 | # '>' if np.linalg.norm(W, ord=2, axis=(0, 1)) > m.scale else '<',
491 | # m.scale,
492 | # )
493 | # )
494 |
--------------------------------------------------------------------------------
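The `compute_weight` loops above are power iteration generalized to arbitrary (domain, codomain) induced norms; for the familiar domain = codomain = 2 case they reduce to spectral-norm power iteration followed by the same soft normalization. A standalone sketch of that special case in plain PyTorch (names are illustrative, not part of the module):

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(64, 32)
u = F.normalize(torch.randn(64), dim=0)
v = F.normalize(torch.randn(32), dim=0)

# Power iteration for the largest singular value (the 2->2 induced norm).
for _ in range(200):
    u = F.normalize(torch.mv(W, v), dim=0)
    v = F.normalize(torch.mv(W.t(), u), dim=0)
sigma = torch.dot(u, torch.mv(W, v))

# Soft normalization: rescale only when sigma exceeds the target coeff.
coeff = 0.97
W_hat = W / torch.clamp(sigma / coeff, min=1.0)
print(sigma.item(), torch.svd(W_hat)[1][0].item())  # second value <= coeff
```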
/resflows/layers/base/utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc as container_abcs
2 | from itertools import repeat
3 |
4 |
5 | def _ntuple(n):
6 |
7 | def parse(x):
8 | if isinstance(x, container_abcs.Iterable):
9 | return x
10 | return tuple(repeat(x, n))
11 |
12 | return parse
13 |
14 |
15 | _single = _ntuple(1)
16 | _pair = _ntuple(2)
17 | _triple = _ntuple(3)
18 | _quadruple = _ntuple(4)
19 |
--------------------------------------------------------------------------------
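These helpers mirror `torch.nn`'s internal argument normalization: a bare int is broadcast to a tuple, while an iterable passes through unchanged. A tiny usage sketch (assuming the package is installed as `resflows`):

```
from resflows.layers.base.utils import _pair, _triple

print(_pair(3))        # (3, 3)
print(_pair((3, 5)))   # (3, 5)
print(_triple(1))      # (1, 1, 1)
```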
/resflows/layers/container.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class SequentialFlow(nn.Module):
5 | """A generalized nn.Sequential container for normalizing flows.
6 | """
7 |
8 | def __init__(self, layersList):
9 | super(SequentialFlow, self).__init__()
10 | self.chain = nn.ModuleList(layersList)
11 |
12 | def forward(self, x, logpx=None):
13 | if logpx is None:
14 | for i in range(len(self.chain)):
15 | x = self.chain[i](x)
16 | return x
17 | else:
18 | for i in range(len(self.chain)):
19 | x, logpx = self.chain[i](x, logpx)
20 | return x, logpx
21 |
22 | def inverse(self, y, logpy=None):
23 | if logpy is None:
24 | for i in range(len(self.chain) - 1, -1, -1):
25 | y = self.chain[i].inverse(y)
26 | return y
27 | else:
28 | for i in range(len(self.chain) - 1, -1, -1):
29 | y, logpy = self.chain[i].inverse(y, logpy)
30 | return y, logpy
31 |
32 |
33 | class Inverse(nn.Module):
34 |
35 | def __init__(self, flow):
36 | super(Inverse, self).__init__()
37 | self.flow = flow
38 |
39 | def forward(self, x, logpx=None):
40 | return self.flow.inverse(x, logpx)
41 |
42 | def inverse(self, y, logpy=None):
43 | return self.flow.forward(y, logpy)
44 |
--------------------------------------------------------------------------------
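A usage sketch for the container (assuming the package is installed as `resflows`): `SequentialFlow` threads `(x, logpx)` through the chain, and `inverse` applies the chain in reverse order, so a round trip recovers both the input and its log-density:

```
import torch
from resflows.layers.container import SequentialFlow
from resflows.layers.elemwise import LogitTransform, ZeroMeanTransform

flow = SequentialFlow([LogitTransform(alpha=0.05), ZeroMeanTransform()])

x = torch.rand(4, 3, 8, 8)
logpx = torch.zeros(4, 1)
z, logpz = flow(x, logpx)                  # forward accumulates log-det corrections
x_rec, logpx_rec = flow.inverse(z, logpz)  # chain applied in reverse order
print(torch.allclose(x, x_rec, atol=1e-5),
      torch.allclose(logpx, logpx_rec, atol=1e-2))  # True True, up to float32 accumulation
```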
/resflows/layers/coupling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import mask_utils
4 |
5 | __all__ = ['CouplingBlock', 'ChannelCouplingBlock', 'MaskedCouplingBlock']
6 |
7 |
8 | class CouplingBlock(nn.Module):
9 | """Basic coupling layer for Tensors of shape (n,d).
10 |
11 | Forward computation:
12 | y_a = x_a
13 |         y_b = x_b * sigmoid(s(x_a) + 2) + t(x_a)
14 |     Inverse computation:
15 |         x_a = y_a
16 |         x_b = (y_b - t(y_a)) / sigmoid(s(y_a) + 2)
17 | """
18 |
19 | def __init__(self, dim, nnet, swap=False):
20 | """
21 | Args:
22 |             dim (int): total input dimension (must be even).
23 |             nnet (nn.Module): maps x_a (dim // 2 features) to [s, t] (dim features).
24 | """
25 | super(CouplingBlock, self).__init__()
26 | assert (dim % 2 == 0)
27 | self.d = dim // 2
28 | self.nnet = nnet
29 | self.swap = swap
30 |
31 | def func_s_t(self, x):
32 | f = self.nnet(x)
33 | s = f[:, :self.d]
34 | t = f[:, self.d:]
35 | return s, t
36 |
37 | def forward(self, x, logpx=None):
38 | """Forward computation of a simple coupling split on the axis=1.
39 | """
40 | x_a = x[:, :self.d] if not self.swap else x[:, self.d:]
41 | x_b = x[:, self.d:] if not self.swap else x[:, :self.d]
42 | y_a, y_b, logdetgrad = self._forward_computation(x_a, x_b)
43 | y = [y_a, y_b] if not self.swap else [y_b, y_a]
44 |
45 | if logpx is None:
46 | return torch.cat(y, dim=1)
47 | else:
48 | return torch.cat(y, dim=1), logpx - logdetgrad.view(x.size(0), -1).sum(1, keepdim=True)
49 |
50 | def inverse(self, y, logpy=None):
51 | """Inverse computation of a simple coupling split on the axis=1.
52 | """
53 | y_a = y[:, :self.d] if not self.swap else y[:, self.d:]
54 | y_b = y[:, self.d:] if not self.swap else y[:, :self.d]
55 | x_a, x_b, logdetgrad = self._inverse_computation(y_a, y_b)
56 | x = [x_a, x_b] if not self.swap else [x_b, x_a]
57 | if logpy is None:
58 | return torch.cat(x, dim=1)
59 | else:
60 | return torch.cat(x, dim=1), logpy + logdetgrad
61 |
62 | def _forward_computation(self, x_a, x_b):
63 | y_a = x_a
64 | s_a, t_a = self.func_s_t(x_a)
65 | scale = torch.sigmoid(s_a + 2.)
66 | y_b = x_b * scale + t_a
67 | logdetgrad = self._logdetgrad(scale)
68 | return y_a, y_b, logdetgrad
69 |
70 | def _inverse_computation(self, y_a, y_b):
71 | x_a = y_a
72 | s_a, t_a = self.func_s_t(y_a)
73 | scale = torch.sigmoid(s_a + 2.)
74 | x_b = (y_b - t_a) / scale
75 | logdetgrad = self._logdetgrad(scale)
76 | return x_a, x_b, logdetgrad
77 |
78 | def _logdetgrad(self, scale):
79 | """
80 | Returns:
81 | Tensor (N, 1): containing ln |det J| where J is the jacobian
82 | """
83 | return torch.log(scale).view(scale.shape[0], -1).sum(1, keepdim=True)
84 |
85 | def extra_repr(self):
86 | return 'dim={d}, swap={swap}'.format(**self.__dict__)
87 |
88 |
89 | class ChannelCouplingBlock(CouplingBlock):
90 | """Channel-wise coupling layer for images.
91 | """
92 |
93 | def __init__(self, dim, nnet, mask_type='channel0'):
94 | if mask_type == 'channel0':
95 | swap = False
96 | elif mask_type == 'channel1':
97 | swap = True
98 | else:
99 | raise ValueError('Unknown mask type.')
100 | super(ChannelCouplingBlock, self).__init__(dim, nnet, swap)
101 | self.mask_type = mask_type
102 |
103 | def extra_repr(self):
104 | return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__)
105 |
106 |
107 | class MaskedCouplingBlock(nn.Module):
108 | """Coupling layer for images implemented using masks.
109 | """
110 |
111 | def __init__(self, dim, nnet, mask_type='checkerboard0'):
112 | nn.Module.__init__(self)
113 | self.d = dim
114 | self.nnet = nnet
115 | self.mask_type = mask_type
116 |
117 | def func_s_t(self, x):
118 | f = self.nnet(x)
119 | s = torch.sigmoid(f[:, :self.d] + 2.)
120 | t = f[:, self.d:]
121 | return s, t
122 |
123 | def forward(self, x, logpx=None):
124 | # get mask
125 | b = mask_utils.get_mask(x, mask_type=self.mask_type)
126 |
127 | # masked forward
128 | x_a = b * x
129 | s, t = self.func_s_t(x_a)
130 | y = (x * s + t) * (1 - b) + x_a
131 |
132 | if logpx is None:
133 | return y
134 | else:
135 | return y, logpx - self._logdetgrad(s, b)
136 |
137 | def inverse(self, y, logpy=None):
138 | # get mask
139 | b = mask_utils.get_mask(y, mask_type=self.mask_type)
140 |
141 | # masked forward
142 | y_a = b * y
143 | s, t = self.func_s_t(y_a)
144 | x = y_a + (1 - b) * (y - t) / s
145 |
146 | if logpy is None:
147 | return x
148 | else:
149 | return x, logpy + self._logdetgrad(s, b)
150 |
151 | def _logdetgrad(self, s, mask):
152 | return torch.log(s).mul_(1 - mask).view(s.shape[0], -1).sum(1, keepdim=True)
153 |
154 | def extra_repr(self):
155 | return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__)
156 |
--------------------------------------------------------------------------------
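A small invertibility check for `CouplingBlock` (assuming the package is installed as `resflows`). Note the contract on `nnet`: it consumes `dim // 2` features and must emit `dim` features, which `func_s_t` splits into `s` and `t`:

```
import torch
import torch.nn as nn
from resflows.layers.coupling import CouplingBlock

dim = 4
nnet = nn.Sequential(nn.Linear(dim // 2, 16), nn.ReLU(), nn.Linear(16, dim))
block = CouplingBlock(dim, nnet, swap=False)

x = torch.randn(8, dim)
y, logpy = block(x, torch.zeros(8, 1))
x_rec, logpx_rec = block.inverse(y, logpy)
print(torch.allclose(x, x_rec, atol=1e-6),
      torch.allclose(logpx_rec, torch.zeros(8, 1), atol=1e-6))  # True True
```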
/resflows/layers/elemwise.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | _DEFAULT_ALPHA = 1e-6
6 |
7 |
8 | class ZeroMeanTransform(nn.Module):
9 |
10 | def __init__(self):
11 | nn.Module.__init__(self)
12 |
13 | def forward(self, x, logpx=None):
14 | x = x - .5
15 | if logpx is None:
16 | return x
17 | return x, logpx
18 |
19 | def inverse(self, y, logpy=None):
20 | y = y + .5
21 | if logpy is None:
22 | return y
23 | return y, logpy
24 |
25 |
26 | class Normalize(nn.Module):
27 |
28 | def __init__(self, mean, std):
29 | nn.Module.__init__(self)
30 | self.register_buffer('mean', torch.as_tensor(mean, dtype=torch.float32))
31 | self.register_buffer('std', torch.as_tensor(std, dtype=torch.float32))
32 |
33 | def forward(self, x, logpx=None):
34 | y = x.clone()
35 | c = len(self.mean)
36 | y[:, :c].sub_(self.mean[None, :, None, None]).div_(self.std[None, :, None, None])
37 | if logpx is None:
38 | return y
39 | else:
40 | return y, logpx - self._logdetgrad(x)
41 |
42 | def inverse(self, y, logpy=None):
43 | x = y.clone()
44 | c = len(self.mean)
45 | x[:, :c].mul_(self.std[None, :, None, None]).add_(self.mean[None, :, None, None])
46 | if logpy is None:
47 | return x
48 | else:
49 | return x, logpy + self._logdetgrad(x)
50 |
51 | def _logdetgrad(self, x):
52 | logdetgrad = (
53 | self.std.abs().log().mul_(-1).view(1, -1, 1, 1).expand(x.shape[0], len(self.std), x.shape[2], x.shape[3])
54 | )
55 | return logdetgrad.reshape(x.shape[0], -1).sum(-1, keepdim=True)
56 |
57 |
58 | class LogitTransform(nn.Module):
59 | """
60 |     The preprocessing step used in Real NVP:
61 |     y = logit(a + (1 - 2a) * x)
62 |     x = (sigmoid(y) - a) / (1 - 2a)
63 | """
64 |
65 | def __init__(self, alpha=_DEFAULT_ALPHA):
66 | nn.Module.__init__(self)
67 | self.alpha = alpha
68 |
69 | def forward(self, x, logpx=None):
70 | s = self.alpha + (1 - 2 * self.alpha) * x
71 | y = torch.log(s) - torch.log(1 - s)
72 | if logpx is None:
73 | return y
74 | return y, logpx - self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True)
75 |
76 | def inverse(self, y, logpy=None):
77 | x = (torch.sigmoid(y) - self.alpha) / (1 - 2 * self.alpha)
78 | if logpy is None:
79 | return x
80 | return x, logpy + self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True)
81 |
82 | def _logdetgrad(self, x):
83 | s = self.alpha + (1 - 2 * self.alpha) * x
84 | logdetgrad = -torch.log(s - s * s) + math.log(1 - 2 * self.alpha)
85 | return logdetgrad
86 |
87 | def __repr__(self):
88 | return ('{name}({alpha})'.format(name=self.__class__.__name__, **self.__dict__))
89 |
--------------------------------------------------------------------------------
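Where `_logdetgrad` above comes from: with s = a + (1 - 2a) x, the forward map has dy/dx = (1 - 2a) / (s - s^2), so log|dy/dx| = -log(s - s^2) + log(1 - 2a). A quick autograd cross-check of that formula (assuming the package is installed as `resflows`):

```
import torch
from resflows.layers.elemwise import LogitTransform

t = LogitTransform(alpha=1e-6)
x = torch.rand(5, 1) * 0.98 + 0.01  # keep away from the endpoints of (0, 1)
x.requires_grad_(True)

y = t(x)
(dydx,) = torch.autograd.grad(y.sum(), x)  # elementwise map, so this is dy/dx
print(torch.allclose(dydx.log(), t._logdetgrad(x), atol=1e-5))  # True
```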
/resflows/layers/glow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class InvertibleLinear(nn.Module):
7 |
8 | def __init__(self, dim):
9 | super(InvertibleLinear, self).__init__()
10 | self.dim = dim
11 | self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)])
12 |
13 | def forward(self, x, logpx=None):
14 | y = F.linear(x, self.weight)
15 | if logpx is None:
16 | return y
17 | else:
18 | return y, logpx - self._logdetgrad
19 |
20 | def inverse(self, y, logpy=None):
21 | x = F.linear(y, self.weight.inverse())
22 | if logpy is None:
23 | return x
24 | else:
25 | return x, logpy + self._logdetgrad
26 |
27 | @property
28 | def _logdetgrad(self):
29 | return torch.log(torch.abs(torch.det(self.weight)))
30 |
31 | def extra_repr(self):
32 | return 'dim={}'.format(self.dim)
33 |
34 |
35 | class InvertibleConv2d(nn.Module):
36 |
37 | def __init__(self, dim):
38 | super(InvertibleConv2d, self).__init__()
39 | self.dim = dim
40 | self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)])
41 |
42 | def forward(self, x, logpx=None):
43 | y = F.conv2d(x, self.weight.view(self.dim, self.dim, 1, 1))
44 | if logpx is None:
45 | return y
46 | else:
47 | return y, logpx - self._logdetgrad.expand_as(logpx) * x.shape[2] * x.shape[3]
48 |
49 | def inverse(self, y, logpy=None):
50 | x = F.conv2d(y, self.weight.inverse().view(self.dim, self.dim, 1, 1))
51 | if logpy is None:
52 | return x
53 | else:
54 | return x, logpy + self._logdetgrad.expand_as(logpy) * x.shape[2] * x.shape[3]
55 |
56 | @property
57 | def _logdetgrad(self):
58 | return torch.log(torch.abs(torch.det(self.weight)))
59 |
60 | def extra_repr(self):
61 | return 'dim={}'.format(self.dim)
62 |
--------------------------------------------------------------------------------
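A sketch checking `InvertibleLinear` against `torch.slogdet` (assuming the package is installed as `resflows`). The permutation initialization has |det| = 1, so the weight is perturbed slightly to make the check non-trivial:

```
import torch
from resflows.layers.glow import InvertibleLinear

layer = InvertibleLinear(3)
with torch.no_grad():
    layer.weight.add_(0.1 * torch.randn(3, 3))

x = torch.randn(5, 3)
y, logpy = layer(x, torch.zeros(5, 1))
x_rec = layer.inverse(y)

_, logabsdet = torch.slogdet(layer.weight)
print(torch.allclose(x, x_rec, atol=1e-5),        # inverse round trip
      torch.allclose(layer._logdetgrad, logabsdet))  # log|det W| agrees
```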
/resflows/layers/iresblock.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | import logging
7 |
8 | logger = logging.getLogger()
9 |
10 | __all__ = ['iResBlock']
11 |
12 |
13 | class iResBlock(nn.Module):
14 |
15 | def __init__(
16 | self,
17 | nnet,
18 | geom_p=0.5,
19 | lamb=2.,
20 | n_power_series=None,
21 | exact_trace=False,
22 | brute_force=False,
23 | n_samples=1,
24 | n_exact_terms=2,
25 | n_dist='geometric',
26 | neumann_grad=True,
27 | grad_in_forward=False,
28 | ):
29 | """
30 | Args:
31 | nnet: a nn.Module
32 | n_power_series: number of power series. If not None, uses a biased approximation to logdet.
33 | exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian.
34 | brute_force: Computes the exact logdet. Only available for 2D inputs.
35 | """
36 | nn.Module.__init__(self)
37 | self.nnet = nnet
38 | self.n_dist = n_dist
39 | self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1. - geom_p)))
40 | self.lamb = nn.Parameter(torch.tensor(lamb))
41 | self.n_samples = n_samples
42 | self.n_power_series = n_power_series
43 | self.exact_trace = exact_trace
44 | self.brute_force = brute_force
45 | self.n_exact_terms = n_exact_terms
46 | self.grad_in_forward = grad_in_forward
47 | self.neumann_grad = neumann_grad
48 |
49 | # store the samples of n.
50 | self.register_buffer('last_n_samples', torch.zeros(self.n_samples))
51 | self.register_buffer('last_firmom', torch.zeros(1))
52 | self.register_buffer('last_secmom', torch.zeros(1))
53 |
54 | def forward(self, x, logpx=None):
55 | if logpx is None:
56 | y = x + self.nnet(x)
57 | return y
58 | else:
59 | g, logdetgrad = self._logdetgrad(x)
60 | return x + g, logpx - logdetgrad
61 |
62 | def inverse(self, y, logpy=None):
63 | x = self._inverse_fixed_point(y)
64 | if logpy is None:
65 | return x
66 | else:
67 | return x, logpy + self._logdetgrad(x)[1]
68 |
69 | def _inverse_fixed_point(self, y, atol=1e-5, rtol=1e-5):
70 | x, x_prev = y - self.nnet(y), y
71 | i = 0
72 | tol = atol + y.abs() * rtol
73 | while not torch.all((x - x_prev)**2 / tol < 1):
74 | x, x_prev = y - self.nnet(x), x
75 | i += 1
76 | if i > 1000:
77 | logger.info('Iterations exceeded 1000 for inverse.')
78 | break
79 | return x
80 |
81 | def _logdetgrad(self, x):
82 | """Returns g(x) and logdet|d(x+g(x))/dx|."""
83 |
84 | with torch.enable_grad():
85 | if (self.brute_force or not self.training) and (x.ndimension() == 2 and x.shape[1] == 2):
86 | ###########################################
87 | # Brute-force compute Jacobian determinant.
88 | ###########################################
89 | x = x.requires_grad_(True)
90 | g = self.nnet(x)
91 | # Brute-force logdet only available for 2D.
92 | jac = batch_jacobian(g, x)
93 | batch_dets = (jac[:, 0, 0] + 1) * (jac[:, 1, 1] + 1) - jac[:, 0, 1] * jac[:, 1, 0]
94 | return g, torch.log(torch.abs(batch_dets)).view(-1, 1)
95 |
96 | if self.n_dist == 'geometric':
97 | geom_p = torch.sigmoid(self.geom_p).item()
98 | sample_fn = lambda m: geometric_sample(geom_p, m)
99 | rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset)
100 | elif self.n_dist == 'poisson':
101 | lamb = self.lamb.item()
102 | sample_fn = lambda m: poisson_sample(lamb, m)
103 | rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)
104 |
105 | if self.training:
106 | if self.n_power_series is None:
107 | # Unbiased estimation.
108 | lamb = self.lamb.item()
109 | n_samples = sample_fn(self.n_samples)
110 | n_power_series = max(n_samples) + self.n_exact_terms
111 | coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \
112 | sum(n_samples >= k - self.n_exact_terms) / len(n_samples)
113 | else:
114 | # Truncated estimation.
115 | n_power_series = self.n_power_series
116 | coeff_fn = lambda k: 1.
117 | else:
118 | # Unbiased estimation with more exact terms.
119 | lamb = self.lamb.item()
120 | n_samples = sample_fn(self.n_samples)
121 | n_power_series = max(n_samples) + 20
122 | coeff_fn = lambda k: 1 / rcdf_fn(k, 20) * \
123 | sum(n_samples >= k - 20) / len(n_samples)
124 |
125 | if not self.exact_trace:
126 | ####################################
127 | # Power series with trace estimator.
128 | ####################################
129 | vareps = torch.randn_like(x)
130 |
131 | # Choose the type of estimator.
132 | if self.training and self.neumann_grad:
133 | estimator_fn = neumann_logdet_estimator
134 | else:
135 | estimator_fn = basic_logdet_estimator
136 |
137 | # Do backprop-in-forward to save memory.
138 | if self.training and self.grad_in_forward:
139 | g, logdetgrad = mem_eff_wrapper(
140 | estimator_fn, self.nnet, x, n_power_series, vareps, coeff_fn, self.training
141 | )
142 | else:
143 | x = x.requires_grad_(True)
144 | g = self.nnet(x)
145 | logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, self.training)
146 | else:
147 | ############################################
148 | # Power series with exact trace computation.
149 | ############################################
150 | x = x.requires_grad_(True)
151 | g = self.nnet(x)
152 | jac = batch_jacobian(g, x)
153 | logdetgrad = batch_trace(jac)
154 | jac_k = jac
155 | for k in range(2, n_power_series + 1):
156 | jac_k = torch.bmm(jac, jac_k)
157 | logdetgrad = logdetgrad + (-1)**(k + 1) / k * coeff_fn(k) * batch_trace(jac_k)
158 |
159 | if self.training and self.n_power_series is None:
160 | self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples))
161 | estimator = logdetgrad.detach()
162 | self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom))
163 | self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom))
164 | return g, logdetgrad.view(-1, 1)
165 |
166 | def extra_repr(self):
167 | return 'dist={}, n_samples={}, n_power_series={}, neumann_grad={}, exact_trace={}, brute_force={}'.format(
168 | self.n_dist, self.n_samples, self.n_power_series, self.neumann_grad, self.exact_trace, self.brute_force
169 | )
170 |
171 |
172 | def batch_jacobian(g, x):
173 | jac = []
174 | for d in range(g.shape[1]):
175 | jac.append(torch.autograd.grad(torch.sum(g[:, d]), x, create_graph=True)[0].view(x.shape[0], 1, x.shape[1]))
176 | return torch.cat(jac, 1)
177 |
178 |
179 | def batch_trace(M):
180 | return M.view(M.shape[0], -1)[:, ::M.shape[1] + 1].sum(1)
181 |
182 |
183 | #####################
184 | # Logdet Estimators
185 | #####################
186 | class MemoryEfficientLogDetEstimator(torch.autograd.Function):
187 |
188 | @staticmethod
189 | def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
190 | ctx.training = training
191 | with torch.enable_grad():
192 | x = x.detach().requires_grad_(True)
193 | g = gnet(x)
194 | ctx.g = g
195 | ctx.x = x
196 | logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)
197 |
198 | if training:
199 | grad_x, *grad_params = torch.autograd.grad(
200 | logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
201 | )
202 | if grad_x is None:
203 | grad_x = torch.zeros_like(x)
204 | ctx.save_for_backward(grad_x, *g_params, *grad_params)
205 |
206 | return safe_detach(g), safe_detach(logdetgrad)
207 |
208 | @staticmethod
209 | def backward(ctx, grad_g, grad_logdetgrad):
210 | training = ctx.training
211 | if not training:
212 | raise ValueError('Provide training=True if using backward.')
213 |
214 | with torch.enable_grad():
215 | grad_x, *params_and_grad = ctx.saved_tensors
216 | g, x = ctx.g, ctx.x
217 |
218 | # Precomputed gradients.
219 | g_params = params_and_grad[:len(params_and_grad) // 2]
220 | grad_params = params_and_grad[len(params_and_grad) // 2:]
221 |
222 | dg_x, *dg_params = torch.autograd.grad(g, [x] + g_params, grad_g, allow_unused=True)
223 |
224 | # Update based on gradient from logdetgrad.
225 | dL = grad_logdetgrad[0].detach()
226 | with torch.no_grad():
227 | grad_x.mul_(dL)
228 | grad_params = tuple([g.mul_(dL) if g is not None else None for g in grad_params])
229 |
230 | # Update based on gradient from g.
231 | with torch.no_grad():
232 | grad_x.add_(dg_x)
233 | grad_params = tuple([dg.add_(djac) if djac is not None else dg for dg, djac in zip(dg_params, grad_params)])
234 |
235 | return (None, None, grad_x, None, None, None, None) + grad_params
236 |
237 |
238 | def basic_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
239 | vjp = vareps
240 | logdetgrad = torch.tensor(0.).to(x)
241 | for k in range(1, n_power_series + 1):
242 | vjp = torch.autograd.grad(g, x, vjp, create_graph=training, retain_graph=True)[0]
243 | tr = torch.sum(vjp.view(x.shape[0], -1) * vareps.view(x.shape[0], -1), 1)
244 | delta = (-1)**(k + 1) / k * coeff_fn(k) * tr
245 | logdetgrad = logdetgrad + delta
246 | return logdetgrad
247 |
248 |
249 | def neumann_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
250 | vjp = vareps
251 | neumann_vjp = vareps
252 | with torch.no_grad():
253 | for k in range(1, n_power_series + 1):
254 | vjp = torch.autograd.grad(g, x, vjp, retain_graph=True)[0]
255 | neumann_vjp = neumann_vjp + (-1)**k * coeff_fn(k) * vjp
256 | vjp_jac = torch.autograd.grad(g, x, neumann_vjp, create_graph=training)[0]
257 | logdetgrad = torch.sum(vjp_jac.view(x.shape[0], -1) * vareps.view(x.shape[0], -1), 1)
258 | return logdetgrad
259 |
260 |
261 | def mem_eff_wrapper(estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training):
262 |
263 | # We need this in order to access the variables inside this module,
264 | # since we have no other way of getting variables along the execution path.
265 | if not isinstance(gnet, nn.Module):
266 | raise ValueError('g is required to be an instance of nn.Module.')
267 |
268 | return MemoryEfficientLogDetEstimator.apply(
269 | estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *list(gnet.parameters())
270 | )
271 |
272 |
273 | # -------- Helper distribution functions --------
274 | # These take python ints or floats, not PyTorch tensors.
275 |
276 |
277 | def geometric_sample(p, n_samples):
278 | return np.random.geometric(p, n_samples)
279 |
280 |
281 | def geometric_1mcdf(p, k, offset):
282 |     """P(n >= k), with the first `offset` terms counted as certain."""
283 |     if k <= offset:
284 |         return 1.
285 |     else:
286 |         k = k - offset
287 |     return (1 - p)**max(k - 1, 0)
288 |
289 |
290 | def poisson_sample(lamb, n_samples):
291 | return np.random.poisson(lamb, n_samples)
292 |
293 |
294 | def poisson_1mcdf(lamb, k, offset):
295 |     """P(n >= k), with the first `offset` terms counted as certain."""
296 |     if k <= offset:
297 |         return 1.
298 |     else:
299 |         k = k - offset
300 |     total = 1.
301 |     for i in range(1, k):
302 |         total += lamb**i / math.factorial(i)
303 |     return 1 - np.exp(-lamb) * total
304 |
305 |
306 | def sample_rademacher_like(y):
307 | return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1
308 |
309 |
310 | # -------------- Helper functions --------------
311 |
312 |
313 | def safe_detach(tensor):
314 | return tensor.detach().requires_grad_(tensor.requires_grad)
315 |
316 |
317 | def _flatten(sequence):
318 | flat = [p.reshape(-1) for p in sequence]
319 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
320 |
321 |
322 | def _flatten_convert_none_to_zeros(sequence, like_sequence):
323 | flat = [p.reshape(-1) if p is not None else torch.zeros_like(q).view(-1) for p, q in zip(sequence, like_sequence)]
324 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
325 |
--------------------------------------------------------------------------------
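A minimal end-to-end sketch of `iResBlock` on 2D inputs (assuming the package is installed as `resflows`). The residual branch must be a strict contraction for the fixed-point inverse to converge; here a tiny MLP is shrunk by hand instead of using the repo's Lipschitz-constrained layers, and eval mode triggers the brute-force log-determinant path for 2D inputs:

```
import torch
import torch.nn as nn
from resflows.layers.iresblock import iResBlock

nnet = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))
with torch.no_grad():
    for p in nnet.parameters():
        p.mul_(0.2)  # crude rescaling to keep Lip(nnet) well below 1

block = iResBlock(nnet)
block.eval()  # 2D inputs then take the brute-force log-det branch

x = torch.randn(16, 2)
y, logpy = block(x, torch.zeros(16, 1))
x_rec = block.inverse(y)
print((x - x_rec).abs().max().item())  # small fixed-point inversion error
```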
/resflows/layers/mask_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def _get_checkerboard_mask(x, swap=False):
5 | n, c, h, w = x.size()
6 |
7 | H = ((h - 1) // 2 + 1) * 2 # H = h + 1 if h is odd and h if h is even
8 | W = ((w - 1) // 2 + 1) * 2
9 |
10 | # construct checkerboard mask
11 | if not swap:
12 | mask = torch.Tensor([[1, 0], [0, 1]]).repeat(H // 2, W // 2)
13 | else:
14 | mask = torch.Tensor([[0, 1], [1, 0]]).repeat(H // 2, W // 2)
15 | mask = mask[:h, :w]
16 | mask = mask.contiguous().view(1, 1, h, w).expand(n, c, h, w).type_as(x.data)
17 |
18 | return mask
19 |
20 |
21 | def _get_channel_mask(x, swap=False):
22 | n, c, h, w = x.size()
23 | assert (c % 2 == 0)
24 |
25 | # construct channel-wise mask
26 | mask = torch.zeros(x.size())
27 | if not swap:
28 | mask[:, :c // 2] = 1
29 | else:
30 | mask[:, c // 2:] = 1
31 | return mask
32 |
33 |
34 | def get_mask(x, mask_type=None):
35 | if mask_type is None:
36 | return torch.zeros(x.size()).to(x)
37 | elif mask_type == 'channel0':
38 | return _get_channel_mask(x, swap=False)
39 | elif mask_type == 'channel1':
40 | return _get_channel_mask(x, swap=True)
41 | elif mask_type == 'checkerboard0':
42 | return _get_checkerboard_mask(x, swap=False)
43 | elif mask_type == 'checkerboard1':
44 | return _get_checkerboard_mask(x, swap=True)
45 | else:
46 | raise ValueError('Unknown mask type {}'.format(mask_type))
47 |
--------------------------------------------------------------------------------
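What the masks look like, concretely (assuming the package is installed as `resflows`): checkerboard masks alternate per pixel and come in two phases, while channel masks zero out half the channels:

```
import torch
from resflows.layers.mask_utils import get_mask

x = torch.zeros(1, 2, 4, 4)
print(get_mask(x, 'checkerboard0')[0, 0])
# tensor([[1., 0., 1., 0.],
#         [0., 1., 0., 1.],
#         [1., 0., 1., 0.],
#         [0., 1., 0., 1.]])
print(get_mask(x, 'channel0')[0, :, 0, 0])  # tensor([1., 0.])
```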
/resflows/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Parameter
4 |
5 | __all__ = ['MovingBatchNorm1d', 'MovingBatchNorm2d']
6 |
7 |
8 | class MovingBatchNormNd(nn.Module):
9 |
10 | def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
11 | super(MovingBatchNormNd, self).__init__()
12 | self.num_features = num_features
13 | self.affine = affine
14 | self.eps = eps
15 | self.decay = decay
16 | self.bn_lag = bn_lag
17 | self.register_buffer('step', torch.zeros(1))
18 | if self.affine:
19 | self.bias = Parameter(torch.Tensor(num_features))
20 | else:
21 | self.register_parameter('bias', None)
22 | self.register_buffer('running_mean', torch.zeros(num_features))
23 | self.reset_parameters()
24 |
25 | @property
26 | def shape(self):
27 | raise NotImplementedError
28 |
29 | def reset_parameters(self):
30 | self.running_mean.zero_()
31 | if self.affine:
32 | self.bias.data.zero_()
33 |
34 | def forward(self, x, logpx=None):
35 | c = x.size(1)
36 | used_mean = self.running_mean.clone().detach()
37 |
38 | if self.training:
39 | # compute batch statistics
40 | x_t = x.transpose(0, 1).contiguous().view(c, -1)
41 | batch_mean = torch.mean(x_t, dim=1)
42 |
43 | # moving average
44 | if self.bn_lag > 0:
45 | used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
46 | used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
47 |
48 | # update running estimates
49 | self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
50 | self.step += 1
51 |
52 | # perform normalization
53 | used_mean = used_mean.view(*self.shape).expand_as(x)
54 |
55 | y = x - used_mean
56 |
57 | if self.affine:
58 | bias = self.bias.view(*self.shape).expand_as(x)
59 | y = y + bias
60 |
61 | if logpx is None:
62 | return y
63 | else:
64 | return y, logpx
65 |
66 | def inverse(self, y, logpy=None):
67 | used_mean = self.running_mean
68 |
69 | if self.affine:
70 | bias = self.bias.view(*self.shape).expand_as(y)
71 | y = y - bias
72 |
73 | used_mean = used_mean.view(*self.shape).expand_as(y)
74 | x = y + used_mean
75 |
76 | if logpy is None:
77 | return x
78 | else:
79 | return x, logpy
80 |
81 | def __repr__(self):
82 | return (
83 | '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
84 | ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
85 | )
86 |
87 |
88 | class MovingBatchNorm1d(MovingBatchNormNd):
89 |
90 | @property
91 | def shape(self):
92 | return [1, -1]
93 |
94 |
95 | class MovingBatchNorm2d(MovingBatchNormNd):
96 |
97 | @property
98 | def shape(self):
99 | return [1, -1, 1, 1]
100 |
--------------------------------------------------------------------------------
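Note that `MovingBatchNormNd` only re-centers activations (there is no variance term), so it is volume-preserving and `logpx` passes through unchanged. A usage sketch (assuming the package is installed as `resflows`):

```
import torch
from resflows.layers.normalization import MovingBatchNorm1d

bn = MovingBatchNorm1d(4)
bn.train()
for _ in range(200):
    bn(torch.randn(32, 4) + 5.0)  # running_mean drifts toward the data mean

bn.eval()
x = torch.randn(8, 4) + 5.0
y = bn(x)
print(bn.running_mean)                   # roughly 5 in every coordinate
print(torch.allclose(bn.inverse(y), x))  # True: centering is exactly invertible
```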
/resflows/layers/squeeze.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __all__ = ['SqueezeLayer']
5 |
6 |
7 | class SqueezeLayer(nn.Module):
8 |
9 | def __init__(self, downscale_factor):
10 | super(SqueezeLayer, self).__init__()
11 | self.downscale_factor = downscale_factor
12 |
13 | def forward(self, x, logpx=None):
14 | squeeze_x = squeeze(x, self.downscale_factor)
15 | if logpx is None:
16 | return squeeze_x
17 | else:
18 | return squeeze_x, logpx
19 |
20 | def inverse(self, y, logpy=None):
21 | unsqueeze_y = unsqueeze(y, self.downscale_factor)
22 | if logpy is None:
23 | return unsqueeze_y
24 | else:
25 | return unsqueeze_y, logpy
26 |
27 |
28 | def unsqueeze(input, upscale_factor=2):
29 | return torch.pixel_shuffle(input, upscale_factor)
30 |
31 |
32 | def squeeze(input, downscale_factor=2):
33 | '''
34 | [:, C, H*r, W*r] -> [:, C*r^2, H, W]
35 | '''
36 | batch_size, in_channels, in_height, in_width = input.shape
37 | out_channels = in_channels * (downscale_factor**2)
38 |
39 | out_height = in_height // downscale_factor
40 | out_width = in_width // downscale_factor
41 |
42 | input_view = input.reshape(batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor)
43 |
44 | output = input_view.permute(0, 1, 3, 5, 2, 4)
45 | return output.reshape(batch_size, out_channels, out_height, out_width)
46 |
--------------------------------------------------------------------------------
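A shape round trip (assuming the package is installed as `resflows`): `squeeze` trades spatial resolution for channels, and `unsqueeze` (pixel shuffle) is its exact inverse, since both are pure permutations of the tensor entries:

```
import torch
from resflows.layers.squeeze import squeeze, unsqueeze

x = torch.randn(2, 3, 8, 8)
y = squeeze(x, 2)
print(y.shape)                          # torch.Size([2, 12, 4, 4])
print(torch.equal(unsqueeze(y, 2), x))  # True: the permutation is lossless
```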
/resflows/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 |
5 | class CosineAnnealingWarmRestarts(_LRScheduler):
6 | r"""Set the learning rate of each parameter group using a cosine annealing
7 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
8 | is the number of epochs since the last restart and :math:`T_{i}` is the number
9 | of epochs between two warm restarts in SGDR:
10 | .. math::
11 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
12 | \cos(\frac{T_{cur}}{T_{i}}\pi))
13 | When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
14 |     When :math:`T_{cur}=0` (after restart), set :math:`\eta_t=\eta_{max}`.
15 | It has been proposed in
16 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
17 | Args:
18 | optimizer (Optimizer): Wrapped optimizer.
19 | T_0 (int): Number of iterations for the first restart.
20 |         T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
21 | eta_min (float, optional): Minimum learning rate. Default: 0.
22 | last_epoch (int, optional): The index of last epoch. Default: -1.
23 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
24 | https://arxiv.org/abs/1608.03983
25 | """
26 |
27 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
28 | if T_0 <= 0 or not isinstance(T_0, int):
29 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
30 | if T_mult < 1 or not isinstance(T_mult, int):
31 |             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
32 | self.T_0 = T_0
33 | self.T_i = T_0
34 | self.T_mult = T_mult
35 | self.eta_min = eta_min
36 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)
37 | self.T_cur = last_epoch
38 |
39 | def get_lr(self):
40 | return [
41 | self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
42 | for base_lr in self.base_lrs
43 | ]
44 |
45 | def step(self, epoch=None):
46 |         """Step can be called after every batch update, i.e. if one epoch has 10 iterations
47 |         (number_of_train_examples / batch_size), we can call scheduler.step(epoch + i / 10) after batch i.
48 |         This function can also be called in an interleaved way.
49 |         Example:
50 |             >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
51 |             >>> for epoch in range(20):
52 |             >>>     scheduler.step()
53 |             >>> scheduler.step(26)
54 |             >>> scheduler.step()  # equivalent to scheduler.step(27), not scheduler.step(20)
55 | """
56 | if epoch is None:
57 | epoch = self.last_epoch + 1
58 | self.T_cur = self.T_cur + 1
59 | if self.T_cur >= self.T_i:
60 | self.T_cur = self.T_cur - self.T_i
61 | self.T_i = self.T_i * self.T_mult
62 | else:
63 | if epoch >= self.T_0:
64 | if self.T_mult == 1:
65 | self.T_cur = epoch % self.T_0
66 | else:
67 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
68 | self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (self.T_mult - 1)
69 | self.T_i = self.T_0 * self.T_mult**(n)
70 | else:
71 | self.T_i = self.T_0
72 | self.T_cur = epoch
73 | self.last_epoch = math.floor(epoch)
74 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
75 | param_group['lr'] = lr
76 |
--------------------------------------------------------------------------------
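A usage sketch matching the fractional-epoch convention in the docstring above (assuming the package is installed as `resflows`, and the early PyTorch 1.x `_LRScheduler` this class was written against): stepping with `epoch + batch / num_batches` anneals the learning rate within each epoch and restarts every `T_0` epochs (then `T_0 * T_mult`, and so on):

```
import torch
from resflows.lr_scheduler import CosineAnnealingWarmRestarts

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2)

num_batches = 5
for epoch in range(3):
    for batch in range(num_batches):
        # ... the optimizer.step() on a minibatch would go here ...
        sched.step(epoch + batch / num_batches)
print(opt.param_groups[0]['lr'])  # partway down the first cosine cycle
```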
/resflows/optimizers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class Adam(Optimizer):
7 | """Implements Adam algorithm.
8 |
9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
10 |
11 | Arguments:
12 | params (iterable): iterable of parameters to optimize or dicts defining
13 | parameter groups
14 | lr (float, optional): learning rate (default: 1e-3)
15 | betas (Tuple[float, float], optional): coefficients used for computing
16 | running averages of gradient and its square (default: (0.9, 0.999))
17 | eps (float, optional): term added to the denominator to improve
18 | numerical stability (default: 1e-8)
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
21 | algorithm from the paper `On the Convergence of Adam and Beyond`_
22 | (default: False)
23 |
24 | .. _Adam\: A Method for Stochastic Optimization:
25 | https://arxiv.org/abs/1412.6980
26 | .. _On the Convergence of Adam and Beyond:
27 | https://openreview.net/forum?id=ryQu7f-RZ
28 | """
29 |
30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
31 | if not 0.0 <= lr:
32 | raise ValueError("Invalid learning rate: {}".format(lr))
33 | if not 0.0 <= eps:
34 | raise ValueError("Invalid epsilon value: {}".format(eps))
35 | if not 0.0 <= betas[0] < 1.0:
36 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
37 | if not 0.0 <= betas[1] < 1.0:
38 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
39 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
40 | super(Adam, self).__init__(params, defaults)
41 |
42 | def __setstate__(self, state):
43 | super(Adam, self).__setstate__(state)
44 | for group in self.param_groups:
45 | group.setdefault('amsgrad', False)
46 |
47 | def step(self, closure=None):
48 | """Performs a single optimization step.
49 |
50 | Arguments:
51 | closure (callable, optional): A closure that reevaluates the model
52 | and returns the loss.
53 | """
54 | loss = None
55 | if closure is not None:
56 | loss = closure()
57 |
58 | for group in self.param_groups:
59 | for p in group['params']:
60 | if p.grad is None:
61 | continue
62 | grad = p.grad.data
63 | if grad.is_sparse:
64 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
65 | amsgrad = group['amsgrad']
66 |
67 | state = self.state[p]
68 |
69 | # State initialization
70 | if len(state) == 0:
71 | state['step'] = 0
72 | # Exponential moving average of gradient values
73 | state['exp_avg'] = torch.zeros_like(p.data)
74 | # Exponential moving average of squared gradient values
75 | state['exp_avg_sq'] = torch.zeros_like(p.data)
76 | if amsgrad:
77 | # Maintains max of all exp. moving avg. of sq. grad. values
78 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
79 |
80 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
81 | if amsgrad:
82 | max_exp_avg_sq = state['max_exp_avg_sq']
83 | beta1, beta2 = group['betas']
84 |
85 | state['step'] += 1
86 |
87 | # Decay the first and second moment running average coefficient
88 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
89 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
90 | if amsgrad:
91 | # Maintains the maximum of all 2nd moment running avg. till now
92 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
93 | # Use the max. for normalizing running avg. of gradient
94 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
95 | else:
96 | denom = exp_avg_sq.sqrt().add_(group['eps'])
97 |
98 | bias_correction1 = 1 - beta1**state['step']
99 | bias_correction2 = 1 - beta2**state['step']
100 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
101 |
102 | p.data.addcdiv_(-step_size, exp_avg, denom)
103 |
104 | if group['weight_decay'] != 0:
105 |                     p.data.add_(-step_size * group['weight_decay'], p.data)
106 |
107 | return loss
108 |
109 |
110 | class Adamax(Optimizer):
111 | """Implements Adamax algorithm (a variant of Adam based on infinity norm).
112 |
113 | It has been proposed in `Adam: A Method for Stochastic Optimization`__.
114 |
115 | Arguments:
116 | params (iterable): iterable of parameters to optimize or dicts defining
117 | parameter groups
118 | lr (float, optional): learning rate (default: 2e-3)
119 | betas (Tuple[float, float], optional): coefficients used for computing
120 | running averages of gradient and its square
121 | eps (float, optional): term added to the denominator to improve
122 | numerical stability (default: 1e-8)
123 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
124 |
125 | __ https://arxiv.org/abs/1412.6980
126 | """
127 |
128 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
129 | if not 0.0 <= lr:
130 | raise ValueError("Invalid learning rate: {}".format(lr))
131 | if not 0.0 <= eps:
132 | raise ValueError("Invalid epsilon value: {}".format(eps))
133 | if not 0.0 <= betas[0] < 1.0:
134 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
135 | if not 0.0 <= betas[1] < 1.0:
136 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
137 | if not 0.0 <= weight_decay:
138 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
139 |
140 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
141 | super(Adamax, self).__init__(params, defaults)
142 |
143 | def step(self, closure=None):
144 | """Performs a single optimization step.
145 |
146 | Arguments:
147 | closure (callable, optional): A closure that reevaluates the model
148 | and returns the loss.
149 | """
150 | loss = None
151 | if closure is not None:
152 | loss = closure()
153 |
154 | for group in self.param_groups:
155 | for p in group['params']:
156 | if p.grad is None:
157 | continue
158 | grad = p.grad.data
159 | if grad.is_sparse:
160 | raise RuntimeError('Adamax does not support sparse gradients')
161 | state = self.state[p]
162 |
163 | # State initialization
164 | if len(state) == 0:
165 | state['step'] = 0
166 | state['exp_avg'] = torch.zeros_like(p.data)
167 | state['exp_inf'] = torch.zeros_like(p.data)
168 |
169 | exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
170 | beta1, beta2 = group['betas']
171 | eps = group['eps']
172 |
173 | state['step'] += 1
174 |
175 | # Update biased first moment estimate.
176 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
177 | # Update the exponentially weighted infinity norm.
178 | norm_buf = torch.cat([exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0)
179 | torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))
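    |                 # i.e. exp_inf = max(beta2 * exp_inf, |grad| + eps); stacking into norm_buf
    |                 # lets torch.max write the result into exp_inf in place.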
180 |
181 | bias_correction = 1 - beta1**state['step']
182 | clr = group['lr'] / bias_correction
183 |
184 | p.data.addcdiv_(-clr, exp_avg, exp_inf)
185 |
186 | if group['weight_decay'] != 0:
187 |                     p.data.add_(-clr * group['weight_decay'], p.data)
188 |
189 | return loss
190 |
191 |
192 | class RMSprop(Optimizer):
193 | """Implements RMSprop algorithm.
194 |
195 |     Proposed by G. Hinton in his
196 |     `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
197 | 
198 |     The centered version first appears in `Generating Sequences
199 |     With Recurrent Neural Networks <https://arxiv.org/abs/1308.0850>`_.
200 |
201 | Arguments:
202 | params (iterable): iterable of parameters to optimize or dicts defining
203 | parameter groups
204 | lr (float, optional): learning rate (default: 1e-2)
205 | momentum (float, optional): momentum factor (default: 0)
206 | alpha (float, optional): smoothing constant (default: 0.99)
207 | eps (float, optional): term added to the denominator to improve
208 | numerical stability (default: 1e-8)
209 |         centered (bool, optional): if ``True``, compute the centered RMSprop,
210 |             in which the gradient is normalized by an estimate of its variance
211 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
212 |
213 | """
214 |
215 | def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
216 | if not 0.0 <= lr:
217 | raise ValueError("Invalid learning rate: {}".format(lr))
218 | if not 0.0 <= eps:
219 | raise ValueError("Invalid epsilon value: {}".format(eps))
220 | if not 0.0 <= momentum:
221 | raise ValueError("Invalid momentum value: {}".format(momentum))
222 | if not 0.0 <= weight_decay:
223 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
224 | if not 0.0 <= alpha:
225 | raise ValueError("Invalid alpha value: {}".format(alpha))
226 |
227 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
228 | super(RMSprop, self).__init__(params, defaults)
229 |
230 | def __setstate__(self, state):
231 | super(RMSprop, self).__setstate__(state)
232 | for group in self.param_groups:
233 | group.setdefault('momentum', 0)
234 | group.setdefault('centered', False)
235 |
236 | def step(self, closure=None):
237 | """Performs a single optimization step.
238 |
239 | Arguments:
240 | closure (callable, optional): A closure that reevaluates the model
241 | and returns the loss.
242 | """
243 | loss = None
244 | if closure is not None:
245 | loss = closure()
246 |
247 | for group in self.param_groups:
248 | for p in group['params']:
249 | if p.grad is None:
250 | continue
251 | grad = p.grad.data
252 | if grad.is_sparse:
253 | raise RuntimeError('RMSprop does not support sparse gradients')
254 | state = self.state[p]
255 |
256 | # State initialization
257 | if len(state) == 0:
258 | state['step'] = 0
259 | state['square_avg'] = torch.zeros_like(p.data)
260 | if group['momentum'] > 0:
261 | state['momentum_buffer'] = torch.zeros_like(p.data)
262 | if group['centered']:
263 | state['grad_avg'] = torch.zeros_like(p.data)
264 |
265 | square_avg = state['square_avg']
266 | alpha = group['alpha']
267 |
268 | state['step'] += 1
269 |
270 | square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
271 |
272 | if group['centered']:
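    |                     # Estimate Var[g] ~ E[g^2] - E[g]^2 and normalize by its square root.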
273 | grad_avg = state['grad_avg']
274 | grad_avg.mul_(alpha).add_(1 - alpha, grad)
275 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps'])
276 | else:
277 | avg = square_avg.sqrt().add_(group['eps'])
278 |
279 | if group['momentum'] > 0:
280 | buf = state['momentum_buffer']
281 | buf.mul_(group['momentum']).addcdiv_(grad, avg)
282 | p.data.add_(-group['lr'], buf)
283 | else:
284 | p.data.addcdiv_(-group['lr'], grad, avg)
285 |
286 | if group['weight_decay'] != 0:
287 |                     p.data.add_(-group['lr'] * group['weight_decay'], p.data)
288 |
289 | return loss
290 |
--------------------------------------------------------------------------------
/resflows/resflow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | import resflows.layers as layers
6 | import resflows.layers.base as base_layers
7 |
8 | ACT_FNS = {
9 | 'softplus': lambda b: nn.Softplus(),
10 | 'elu': lambda b: nn.ELU(inplace=b),
11 | 'swish': lambda b: base_layers.Swish(),
12 | 'lcube': lambda b: base_layers.LipschitzCube(),
13 | 'identity': lambda b: base_layers.Identity(),
14 | 'relu': lambda b: nn.ReLU(inplace=b),
15 | }
16 |
17 |
18 | class ResidualFlow(nn.Module):
19 |
20 | def __init__(
21 | self,
22 | input_size,
23 | n_blocks=[16, 16],
24 | intermediate_dim=64,
25 | factor_out=True,
26 | quadratic=False,
27 | init_layer=None,
28 | actnorm=False,
29 | fc_actnorm=False,
30 | batchnorm=False,
31 | dropout=0,
32 | fc=False,
33 | coeff=0.9,
34 | vnorms='122f',
35 | n_lipschitz_iters=None,
36 | sn_atol=None,
37 | sn_rtol=None,
38 | n_power_series=5,
39 | n_dist='geometric',
40 | n_samples=1,
41 | kernels='3-1-3',
42 | activation_fn='elu',
43 | fc_end=True,
44 | fc_idim=128,
45 | n_exact_terms=0,
46 | preact=False,
47 | neumann_grad=True,
48 | grad_in_forward=False,
49 | first_resblock=False,
50 | learn_p=False,
51 | classification=False,
52 | classification_hdim=64,
53 | n_classes=10,
54 | block_type='resblock',
55 | ):
56 | super(ResidualFlow, self).__init__()
57 | self.n_scale = min(len(n_blocks), self._calc_n_scale(input_size))
58 | self.n_blocks = n_blocks
59 | self.intermediate_dim = intermediate_dim
60 | self.factor_out = factor_out
61 | self.quadratic = quadratic
62 | self.init_layer = init_layer
63 | self.actnorm = actnorm
64 | self.fc_actnorm = fc_actnorm
65 | self.batchnorm = batchnorm
66 | self.dropout = dropout
67 | self.fc = fc
68 | self.coeff = coeff
69 | self.vnorms = vnorms
70 | self.n_lipschitz_iters = n_lipschitz_iters
71 | self.sn_atol = sn_atol
72 | self.sn_rtol = sn_rtol
73 | self.n_power_series = n_power_series
74 | self.n_dist = n_dist
75 | self.n_samples = n_samples
76 | self.kernels = kernels
77 | self.activation_fn = activation_fn
78 | self.fc_end = fc_end
79 | self.fc_idim = fc_idim
80 | self.n_exact_terms = n_exact_terms
81 | self.preact = preact
82 | self.neumann_grad = neumann_grad
83 | self.grad_in_forward = grad_in_forward
84 | self.first_resblock = first_resblock
85 | self.learn_p = learn_p
86 | self.classification = classification
87 | self.classification_hdim = classification_hdim
88 | self.n_classes = n_classes
89 | self.block_type = block_type
90 |
91 | if not self.n_scale > 0:
92 |             raise ValueError('Could not compute number of scales for input of size (%d,%d,%d,%d)' % input_size)
93 |
94 | self.transforms = self._build_net(input_size)
95 |
96 | self.dims = [o[1:] for o in self.calc_output_size(input_size)]
97 |
98 | if self.classification:
99 | self.build_multiscale_classifier(input_size)
100 |
101 | def _build_net(self, input_size):
102 | _, c, h, w = input_size
103 | transforms = []
104 | _stacked_blocks = StackediResBlocks if self.block_type == 'resblock' else StackedCouplingBlocks
105 | for i in range(self.n_scale):
106 | transforms.append(
107 | _stacked_blocks(
108 | initial_size=(c, h, w),
109 | idim=self.intermediate_dim,
110 | squeeze=(i < self.n_scale - 1), # don't squeeze last layer
111 | init_layer=self.init_layer if i == 0 else None,
112 | n_blocks=self.n_blocks[i],
113 | quadratic=self.quadratic,
114 | actnorm=self.actnorm,
115 | fc_actnorm=self.fc_actnorm,
116 | batchnorm=self.batchnorm,
117 | dropout=self.dropout,
118 | fc=self.fc,
119 | coeff=self.coeff,
120 | vnorms=self.vnorms,
121 | n_lipschitz_iters=self.n_lipschitz_iters,
122 | sn_atol=self.sn_atol,
123 | sn_rtol=self.sn_rtol,
124 | n_power_series=self.n_power_series,
125 | n_dist=self.n_dist,
126 | n_samples=self.n_samples,
127 | kernels=self.kernels,
128 | activation_fn=self.activation_fn,
129 | fc_end=self.fc_end,
130 | fc_idim=self.fc_idim,
131 | n_exact_terms=self.n_exact_terms,
132 | preact=self.preact,
133 | neumann_grad=self.neumann_grad,
134 | grad_in_forward=self.grad_in_forward,
135 | first_resblock=self.first_resblock and (i == 0),
136 | learn_p=self.learn_p,
137 | )
138 | )
139 | c, h, w = c * 2 if self.factor_out else c * 4, h // 2, w // 2
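    |             # Each squeeze maps (c, h, w) -> (4c, h/2, w/2); with factor_out, half the
    |             # channels are split off afterwards, leaving 2c for the next scale.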
140 | return nn.ModuleList(transforms)
141 |
142 | def _calc_n_scale(self, input_size):
143 | _, _, h, w = input_size
144 | n_scale = 0
145 | while h >= 4 and w >= 4:
146 | n_scale += 1
147 | h = h // 2
148 | w = w // 2
149 | return n_scale
150 |
151 | def calc_output_size(self, input_size):
152 | n, c, h, w = input_size
153 | if not self.factor_out:
154 | k = self.n_scale - 1
155 | return [[n, c * 4**k, h // 2**k, w // 2**k]]
156 | output_sizes = []
157 | for i in range(self.n_scale):
158 | if i < self.n_scale - 1:
159 | c *= 2
160 | h //= 2
161 | w //= 2
162 | output_sizes.append((n, c, h, w))
163 | else:
164 | output_sizes.append((n, c, h, w))
165 | return tuple(output_sizes)
166 |
167 | def build_multiscale_classifier(self, input_size):
168 | n, c, h, w = input_size
169 | hidden_shapes = []
170 | for i in range(self.n_scale):
171 | if i < self.n_scale - 1:
172 | c *= 2 if self.factor_out else 4
173 | h //= 2
174 | w //= 2
175 | hidden_shapes.append((n, c, h, w))
176 |
177 | classification_heads = []
178 | for i, hshape in enumerate(hidden_shapes):
179 | classification_heads.append(
180 | nn.Sequential(
181 | nn.Conv2d(hshape[1], self.classification_hdim, 3, 1, 1),
182 | layers.ActNorm2d(self.classification_hdim),
183 | nn.ReLU(inplace=True),
184 | nn.AdaptiveAvgPool2d((1, 1)),
185 | )
186 | )
187 | self.classification_heads = nn.ModuleList(classification_heads)
188 | self.logit_layer = nn.Linear(self.classification_hdim * len(classification_heads), self.n_classes)
189 |
190 | def forward(self, x, logpx=None, inverse=False, classify=False):
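    |         # Maps x to a flattened latent z; if logpx is given, the change in log-density is
    |         # accumulated alongside. With factor_out, half the variables are split off at each scale.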
191 | if inverse:
192 | return self.inverse(x, logpx)
193 | out = []
194 | if classify: class_outs = []
195 | for idx in range(len(self.transforms)):
196 | if logpx is not None:
197 | x, logpx = self.transforms[idx].forward(x, logpx)
198 | else:
199 | x = self.transforms[idx].forward(x)
200 | if self.factor_out and (idx < len(self.transforms) - 1):
201 | d = x.size(1) // 2
202 | x, f = x[:, :d], x[:, d:]
203 | out.append(f)
204 |
205 | # Handle classification.
206 | if classify:
207 | if self.factor_out:
208 | class_outs.append(self.classification_heads[idx](f))
209 | else:
210 | class_outs.append(self.classification_heads[idx](x))
211 |
212 | out.append(x)
213 | out = torch.cat([o.view(o.size()[0], -1) for o in out], 1)
214 | output = out if logpx is None else (out, logpx)
215 | if classify:
216 | h = torch.cat(class_outs, dim=1).squeeze(-1).squeeze(-1)
217 | logits = self.logit_layer(h)
218 | return output, logits
219 | else:
220 | return output
221 |
222 | def inverse(self, z, logpz=None):
223 | if self.factor_out:
224 | z = z.view(z.shape[0], -1)
225 | zs = []
226 | i = 0
227 | for dims in self.dims:
228 | s = np.prod(dims)
229 | zs.append(z[:, i:i + s])
230 | i += s
231 | zs = [_z.view(_z.size()[0], *zsize) for _z, zsize in zip(zs, self.dims)]
232 |
233 | if logpz is None:
234 | z_prev = self.transforms[-1].inverse(zs[-1])
235 | for idx in range(len(self.transforms) - 2, -1, -1):
236 | z_prev = torch.cat((z_prev, zs[idx]), dim=1)
237 | z_prev = self.transforms[idx].inverse(z_prev)
238 | return z_prev
239 | else:
240 | z_prev, logpz = self.transforms[-1].inverse(zs[-1], logpz)
241 | for idx in range(len(self.transforms) - 2, -1, -1):
242 | z_prev = torch.cat((z_prev, zs[idx]), dim=1)
243 | z_prev, logpz = self.transforms[idx].inverse(z_prev, logpz)
244 | return z_prev, logpz
245 | else:
246 | z = z.view(z.shape[0], *self.dims[-1])
247 | for idx in range(len(self.transforms) - 1, -1, -1):
248 | if logpz is None:
249 | z = self.transforms[idx].inverse(z)
250 | else:
251 | z, logpz = self.transforms[idx].inverse(z, logpz)
252 | return z if logpz is None else (z, logpz)
253 |
254 |
255 | class StackediResBlocks(layers.SequentialFlow):
256 |
257 | def __init__(
258 | self,
259 | initial_size,
260 | idim,
261 | squeeze=True,
262 | init_layer=None,
263 | n_blocks=1,
264 | quadratic=False,
265 | actnorm=False,
266 | fc_actnorm=False,
267 | batchnorm=False,
268 | dropout=0,
269 | fc=False,
270 | coeff=0.9,
271 | vnorms='122f',
272 | n_lipschitz_iters=None,
273 | sn_atol=None,
274 | sn_rtol=None,
275 | n_power_series=5,
276 | n_dist='geometric',
277 | n_samples=1,
278 | kernels='3-1-3',
279 | activation_fn='elu',
280 | fc_end=True,
281 | fc_nblocks=4,
282 | fc_idim=128,
283 | n_exact_terms=0,
284 | preact=False,
285 | neumann_grad=True,
286 | grad_in_forward=False,
287 | first_resblock=False,
288 | learn_p=False,
289 | ):
290 |
291 | chain = []
292 |
293 | # Parse vnorms
294 | ps = []
295 | for p in vnorms:
296 | if p == 'f':
297 | ps.append(float('inf'))
298 | else:
299 | ps.append(float(p))
300 | domains, codomains = ps[:-1], ps[1:]
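    |         # e.g. vnorms='122f' yields per-layer (domain, codomain) norm orders (1,2), (2,2), (2,inf).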
301 | assert len(domains) == len(kernels.split('-'))
302 |
303 | def _actnorm(size, fc):
304 | if fc:
305 | return FCWrapper(layers.ActNorm1d(size[0] * size[1] * size[2]))
306 | else:
307 | return layers.ActNorm2d(size[0])
308 |
309 | def _quadratic_layer(initial_size, fc):
310 | if fc:
311 | c, h, w = initial_size
312 | dim = c * h * w
313 | return FCWrapper(layers.InvertibleLinear(dim))
314 | else:
315 | return layers.InvertibleConv2d(initial_size[0])
316 |
317 | def _lipschitz_layer(fc):
318 | return base_layers.get_linear if fc else base_layers.get_conv2d
319 |
320 | def _resblock(initial_size, fc, idim=idim, first_resblock=False):
321 | if fc:
322 | return layers.iResBlock(
323 | FCNet(
324 | input_shape=initial_size,
325 | idim=idim,
326 | lipschitz_layer=_lipschitz_layer(True),
327 | nhidden=len(kernels.split('-')) - 1,
328 | coeff=coeff,
329 | domains=domains,
330 | codomains=codomains,
331 | n_iterations=n_lipschitz_iters,
332 | activation_fn=activation_fn,
333 | preact=preact,
334 | dropout=dropout,
335 | sn_atol=sn_atol,
336 | sn_rtol=sn_rtol,
337 | learn_p=learn_p,
338 | ),
339 | n_power_series=n_power_series,
340 | n_dist=n_dist,
341 | n_samples=n_samples,
342 | n_exact_terms=n_exact_terms,
343 | neumann_grad=neumann_grad,
344 | grad_in_forward=grad_in_forward,
345 | )
346 | else:
347 | ks = list(map(int, kernels.split('-')))
348 | if learn_p:
349 | _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
350 | _codomains = _domains[1:] + [_domains[0]]
351 | else:
352 | _domains = domains
353 | _codomains = codomains
354 | nnet = []
355 | if not first_resblock and preact:
356 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
357 | nnet.append(ACT_FNS[activation_fn](False))
358 | nnet.append(
359 | _lipschitz_layer(fc)(
360 | initial_size[0], idim, ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
361 | domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
362 | )
363 | )
364 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
365 | nnet.append(ACT_FNS[activation_fn](True))
366 | for i, k in enumerate(ks[1:-1]):
367 | nnet.append(
368 | _lipschitz_layer(fc)(
369 | idim, idim, k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
370 | domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol
371 | )
372 | )
373 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
374 | nnet.append(ACT_FNS[activation_fn](True))
375 | if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
376 | nnet.append(
377 | _lipschitz_layer(fc)(
378 | idim, initial_size[0], ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
379 | domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol
380 | )
381 | )
382 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
383 | return layers.iResBlock(
384 | nn.Sequential(*nnet),
385 | n_power_series=n_power_series,
386 | n_dist=n_dist,
387 | n_samples=n_samples,
388 | n_exact_terms=n_exact_terms,
389 | neumann_grad=neumann_grad,
390 | grad_in_forward=grad_in_forward,
391 | )
392 |
393 | if init_layer is not None: chain.append(init_layer)
394 | if first_resblock and actnorm: chain.append(_actnorm(initial_size, fc))
395 | if first_resblock and fc_actnorm: chain.append(_actnorm(initial_size, True))
396 |
397 | if squeeze:
398 | c, h, w = initial_size
399 | for i in range(n_blocks):
400 | if quadratic: chain.append(_quadratic_layer(initial_size, fc))
401 | chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0)))
402 | if actnorm: chain.append(_actnorm(initial_size, fc))
403 | if fc_actnorm: chain.append(_actnorm(initial_size, True))
404 | chain.append(layers.SqueezeLayer(2))
405 | else:
406 | for _ in range(n_blocks):
407 | if quadratic: chain.append(_quadratic_layer(initial_size, fc))
408 | chain.append(_resblock(initial_size, fc))
409 | if actnorm: chain.append(_actnorm(initial_size, fc))
410 | if fc_actnorm: chain.append(_actnorm(initial_size, True))
411 | # Use four fully connected layers at the end.
412 | if fc_end:
413 | for _ in range(fc_nblocks):
414 | chain.append(_resblock(initial_size, True, fc_idim))
415 | if actnorm or fc_actnorm: chain.append(_actnorm(initial_size, True))
416 |
417 | super(StackediResBlocks, self).__init__(chain)
418 |
419 |
420 | class FCNet(nn.Module):
421 |
422 | def __init__(
423 | self, input_shape, idim, lipschitz_layer, nhidden, coeff, domains, codomains, n_iterations, activation_fn,
424 | preact, dropout, sn_atol, sn_rtol, learn_p, div_in=1
425 | ):
426 | super(FCNet, self).__init__()
427 | self.input_shape = input_shape
428 | c, h, w = self.input_shape
429 | dim = c * h * w
430 | nnet = []
431 | last_dim = dim // div_in
432 | if preact: nnet.append(ACT_FNS[activation_fn](False))
433 | if learn_p:
434 | domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(domains))]
435 | codomains = domains[1:] + [domains[0]]
436 | for i in range(nhidden):
437 | nnet.append(
438 | lipschitz_layer(last_dim, idim) if lipschitz_layer == nn.Linear else lipschitz_layer(
439 | last_dim, idim, coeff=coeff, n_iterations=n_iterations, domain=domains[i], codomain=codomains[i],
440 | atol=sn_atol, rtol=sn_rtol
441 | )
442 | )
443 | nnet.append(ACT_FNS[activation_fn](True))
444 | last_dim = idim
445 | if dropout: nnet.append(nn.Dropout(dropout, inplace=True))
446 | nnet.append(
447 | lipschitz_layer(last_dim, dim) if lipschitz_layer == nn.Linear else lipschitz_layer(
448 | last_dim, dim, coeff=coeff, n_iterations=n_iterations, domain=domains[-1], codomain=codomains[-1],
449 | atol=sn_atol, rtol=sn_rtol
450 | )
451 | )
452 | self.nnet = nn.Sequential(*nnet)
453 |
454 | def forward(self, x):
455 | x = x.view(x.shape[0], -1)
456 | y = self.nnet(x)
457 | return y.view(y.shape[0], *self.input_shape)
458 |
459 |
460 | class FCWrapper(nn.Module):
461 |
462 | def __init__(self, fc_module):
463 | super(FCWrapper, self).__init__()
464 | self.fc_module = fc_module
465 |
466 | def forward(self, x, logpx=None):
467 | shape = x.shape
468 | x = x.view(x.shape[0], -1)
469 | if logpx is None:
470 | y = self.fc_module(x)
471 | return y.view(*shape)
472 | else:
473 | y, logpy = self.fc_module(x, logpx)
474 | return y.view(*shape), logpy
475 |
476 | def inverse(self, y, logpy=None):
477 | shape = y.shape
478 | y = y.view(y.shape[0], -1)
479 | if logpy is None:
480 | x = self.fc_module.inverse(y)
481 | return x.view(*shape)
482 | else:
483 | x, logpx = self.fc_module.inverse(y, logpy)
484 | return x.view(*shape), logpx
485 |
486 |
487 | class StackedCouplingBlocks(layers.SequentialFlow):
488 |
489 | def __init__(
490 | self,
491 | initial_size,
492 | idim,
493 | squeeze=True,
494 | init_layer=None,
495 | n_blocks=1,
496 | quadratic=False,
497 | actnorm=False,
498 | fc_actnorm=False,
499 | batchnorm=False,
500 | dropout=0,
501 | fc=False,
502 | coeff=0.9,
503 | vnorms='122f',
504 | n_lipschitz_iters=None,
505 | sn_atol=None,
506 | sn_rtol=None,
507 | n_power_series=5,
508 | n_dist='geometric',
509 | n_samples=1,
510 | kernels='3-1-3',
511 | activation_fn='elu',
512 | fc_end=True,
513 | fc_nblocks=4,
514 | fc_idim=128,
515 | n_exact_terms=0,
516 | preact=False,
517 | neumann_grad=True,
518 | grad_in_forward=False,
519 | first_resblock=False,
520 | learn_p=False,
521 | ):
522 |
523 | # yapf: disable
524 | class nonloc_scope: pass
525 | nonloc_scope.swap = True
526 | # yapf: enable
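    |         # Mutable holder standing in for `nonlocal`, so the nested _resblock below
    |         # can toggle `swap` between calls.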
527 |
528 | chain = []
529 |
530 | def _actnorm(size, fc):
531 | if fc:
532 | return FCWrapper(layers.ActNorm1d(size[0] * size[1] * size[2]))
533 | else:
534 | return layers.ActNorm2d(size[0])
535 |
536 | def _quadratic_layer(initial_size, fc):
537 | if fc:
538 | c, h, w = initial_size
539 | dim = c * h * w
540 | return FCWrapper(layers.InvertibleLinear(dim))
541 | else:
542 | return layers.InvertibleConv2d(initial_size[0])
543 |
544 | def _weight_layer(fc):
545 | return nn.Linear if fc else nn.Conv2d
546 |
547 | def _resblock(initial_size, fc, idim=idim, first_resblock=False):
548 | if fc:
549 | nonloc_scope.swap = not nonloc_scope.swap
550 | return layers.CouplingBlock(
551 | initial_size[0],
552 | FCNet(
553 | input_shape=initial_size,
554 | idim=idim,
555 | lipschitz_layer=_weight_layer(True),
556 | nhidden=len(kernels.split('-')) - 1,
557 | activation_fn=activation_fn,
558 | preact=preact,
559 | dropout=dropout,
560 | coeff=None,
561 | domains=None,
562 | codomains=None,
563 | n_iterations=None,
564 | sn_atol=None,
565 | sn_rtol=None,
566 | learn_p=None,
567 | div_in=2,
568 | ),
569 | swap=nonloc_scope.swap,
570 | )
571 | else:
572 | ks = list(map(int, kernels.split('-')))
573 |
574 | if init_layer is None:
575 | _block = layers.ChannelCouplingBlock
576 | _mask_type = 'channel'
577 | div_in = 2
578 | mult_out = 1
579 | else:
580 | _block = layers.MaskedCouplingBlock
581 | _mask_type = 'checkerboard'
582 | div_in = 1
583 | mult_out = 2
584 |
585 | nonloc_scope.swap = not nonloc_scope.swap
586 | _mask_type += '1' if nonloc_scope.swap else '0'
587 |
588 | nnet = []
589 | if not first_resblock and preact:
590 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
591 | nnet.append(ACT_FNS[activation_fn](False))
592 | nnet.append(_weight_layer(fc)(initial_size[0] // div_in, idim, ks[0], 1, ks[0] // 2))
593 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
594 | nnet.append(ACT_FNS[activation_fn](True))
595 | for i, k in enumerate(ks[1:-1]):
596 | nnet.append(_weight_layer(fc)(idim, idim, k, 1, k // 2))
597 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
598 | nnet.append(ACT_FNS[activation_fn](True))
599 | if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
600 | nnet.append(_weight_layer(fc)(idim, initial_size[0] * mult_out, ks[-1], 1, ks[-1] // 2))
601 | if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
602 |
603 | return _block(initial_size[0], nn.Sequential(*nnet), mask_type=_mask_type)
604 |
605 | if init_layer is not None: chain.append(init_layer)
606 | if first_resblock and actnorm: chain.append(_actnorm(initial_size, fc))
607 | if first_resblock and fc_actnorm: chain.append(_actnorm(initial_size, True))
608 |
609 | if squeeze:
610 | c, h, w = initial_size
611 | for i in range(n_blocks):
612 | if quadratic: chain.append(_quadratic_layer(initial_size, fc))
613 | chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0)))
614 | if actnorm: chain.append(_actnorm(initial_size, fc))
615 | if fc_actnorm: chain.append(_actnorm(initial_size, True))
616 | chain.append(layers.SqueezeLayer(2))
617 | else:
618 | for _ in range(n_blocks):
619 | if quadratic: chain.append(_quadratic_layer(initial_size, fc))
620 | chain.append(_resblock(initial_size, fc))
621 | if actnorm: chain.append(_actnorm(initial_size, fc))
622 | if fc_actnorm: chain.append(_actnorm(initial_size, True))
623 | # Use four fully connected layers at the end.
624 | if fc_end:
625 | for _ in range(fc_nblocks):
626 | chain.append(_resblock(initial_size, True, fc_idim))
627 | if actnorm or fc_actnorm: chain.append(_actnorm(initial_size, True))
628 |
629 | super(StackedCouplingBlocks, self).__init__(chain)
630 |
--------------------------------------------------------------------------------
/resflows/toy_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn
3 | import sklearn.datasets
4 | from sklearn.utils import shuffle as util_shuffle
5 |
6 |
7 | # Dataset iterator
8 | def inf_train_gen(data, batch_size=200):
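  |     # Generates one batch of 2-D points from the named toy distribution;
  |     # unrecognized names fall back to "8gaussians".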
9 |
10 | if data == "swissroll":
11 | data = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0)[0]
12 | data = data.astype("float32")[:, [0, 2]]
13 | data /= 5
14 | return data
15 |
16 | elif data == "circles":
17 | data = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08)[0]
18 | data = data.astype("float32")
19 | data *= 3
20 | return data
21 |
22 | elif data == "rings":
23 | n_samples4 = n_samples3 = n_samples2 = batch_size // 4
24 | n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2
25 |
26 | # so as not to have the first point = last point, we set endpoint=False
27 | linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False)
28 | linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False)
29 | linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False)
30 | linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False)
31 |
32 | circ4_x = np.cos(linspace4)
33 | circ4_y = np.sin(linspace4)
34 |         circ3_x = np.cos(linspace3) * 0.75
35 | circ3_y = np.sin(linspace3) * 0.75
36 | circ2_x = np.cos(linspace2) * 0.5
37 | circ2_y = np.sin(linspace2) * 0.5
38 | circ1_x = np.cos(linspace1) * 0.25
39 | circ1_y = np.sin(linspace1) * 0.25
40 |
41 | X = np.vstack([
42 | np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]),
43 | np.hstack([circ4_y, circ3_y, circ2_y, circ1_y])
44 | ]).T * 3.0
45 | X = util_shuffle(X)
46 |
47 | # Add noise
48 | X = X + np.random.normal(scale=0.08, size=X.shape)
49 |
50 | return X.astype("float32")
51 |
52 | elif data == "moons":
53 | data = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0]
54 | data = data.astype("float32")
55 | data = data * 2 + np.array([-1, -0.2])
56 | return data
57 |
58 | elif data == "8gaussians":
59 | scale = 4.
60 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)),
61 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2),
62 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))]
63 | centers = [(scale * x, scale * y) for x, y in centers]
64 |
65 | dataset = []
66 | for i in range(batch_size):
67 | point = np.random.randn(2) * 0.5
68 | idx = np.random.randint(8)
69 | center = centers[idx]
70 | point[0] += center[0]
71 | point[1] += center[1]
72 | dataset.append(point)
73 | dataset = np.array(dataset, dtype="float32")
74 | dataset /= 1.414
75 | return dataset
76 |
77 | elif data == "pinwheel":
78 | radial_std = 0.3
79 | tangential_std = 0.1
80 | num_classes = 5
81 | num_per_class = batch_size // 5
82 | rate = 0.25
83 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
84 |
85 | features = np.random.randn(num_classes*num_per_class, 2) \
86 | * np.array([radial_std, tangential_std])
87 | features[:, 0] += 1.
88 | labels = np.repeat(np.arange(num_classes), num_per_class)
89 |
90 | angles = rads[labels] + rate * np.exp(features[:, 0])
91 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
92 | rotations = np.reshape(rotations.T, (-1, 2, 2))
93 |
94 | return 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations))
95 |
96 | elif data == "2spirals":
97 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360
98 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5
99 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5
100 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3
101 | x += np.random.randn(*x.shape) * 0.1
102 | return x
103 |
104 | elif data == "checkerboard":
105 | x1 = np.random.rand(batch_size) * 4 - 2
106 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2
107 | x2 = x2_ + (np.floor(x1) % 2)
108 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2
109 |
110 | elif data == "line":
111 | x = np.random.rand(batch_size) * 5 - 2.5
112 | y = x
113 | return np.stack((x, y), 1)
114 | elif data == "cos":
115 | x = np.random.rand(batch_size) * 5 - 2.5
116 | y = np.sin(x) * 2.5
117 | return np.stack((x, y), 1)
118 | else:
119 | return inf_train_gen("8gaussians", batch_size)
120 |
--------------------------------------------------------------------------------
/resflows/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from numbers import Number
4 | import logging
5 | import torch
6 |
7 |
8 | def makedirs(dirname):
9 | if not os.path.exists(dirname):
10 | os.makedirs(dirname)
11 |
12 |
13 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
14 | logger = logging.getLogger()
15 | if debug:
16 | level = logging.DEBUG
17 | else:
18 | level = logging.INFO
19 | logger.setLevel(level)
20 | if saving:
21 | info_file_handler = logging.FileHandler(logpath, mode="a")
22 | info_file_handler.setLevel(level)
23 | logger.addHandler(info_file_handler)
24 | if displaying:
25 | console_handler = logging.StreamHandler()
26 | console_handler.setLevel(level)
27 | logger.addHandler(console_handler)
28 | logger.info(filepath)
29 | with open(filepath, "r") as f:
30 | logger.info(f.read())
31 |
32 | for f in package_files:
33 | logger.info(f)
34 | with open(f, "r") as package_f:
35 | logger.info(package_f.read())
36 |
37 | return logger
38 |
39 |
40 | class AverageMeter(object):
41 | """Computes and stores the average and current value"""
42 |
43 | def __init__(self):
44 | self.reset()
45 |
46 | def reset(self):
47 | self.val = 0
48 | self.avg = 0
49 | self.sum = 0
50 | self.count = 0
51 |
52 | def update(self, val, n=1):
53 | self.val = val
54 | self.sum += val * n
55 | self.count += n
56 | self.avg = self.sum / self.count
57 |
58 |
59 | class RunningAverageMeter(object):
60 |     """Computes and stores an exponential moving average and the current value"""
61 |
62 | def __init__(self, momentum=0.99):
63 | self.momentum = momentum
64 | self.reset()
65 |
66 | def reset(self):
67 | self.val = None
68 | self.avg = 0
69 |
70 | def update(self, val):
71 | if self.val is None:
72 | self.avg = val
73 | else:
74 | self.avg = self.avg * self.momentum + val * (1 - self.momentum)
75 | self.val = val
76 |
77 |
78 | def inf_generator(iterable):
79 | """Allows training with DataLoaders in a single infinite loop:
80 | for i, (x, y) in enumerate(inf_generator(train_loader)):
81 | """
82 | iterator = iterable.__iter__()
83 | while True:
84 | try:
85 | yield iterator.__next__()
86 | except StopIteration:
87 | iterator = iterable.__iter__()
88 |
89 |
90 | def save_checkpoint(state, save, epoch, last_checkpoints=None, num_checkpoints=None):
91 | if not os.path.exists(save):
92 | os.makedirs(save)
93 | filename = os.path.join(save, 'checkpt-%04d.pth' % epoch)
94 | torch.save(state, filename)
95 |
96 | if last_checkpoints is not None and num_checkpoints is not None:
97 | last_checkpoints.append(epoch)
98 | if len(last_checkpoints) > num_checkpoints:
99 | rm_epoch = last_checkpoints.pop(0)
100 | os.remove(os.path.join(save, 'checkpt-%04d.pth' % rm_epoch))
101 |
102 |
103 | def isnan(tensor):
104 | return (tensor != tensor)
105 |
106 |
107 | def logsumexp(value, dim=None, keepdim=False):
108 | """Numerically stable implementation of the operation
109 | value.exp().sum(dim, keepdim).log()
110 | """
111 | if dim is not None:
112 | m, _ = torch.max(value, dim=dim, keepdim=True)
113 | value0 = value - m
114 | if keepdim is False:
115 | m = m.squeeze(dim)
116 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
117 | else:
118 | m = torch.max(value)
119 | sum_exp = torch.sum(torch.exp(value - m))
120 | if isinstance(sum_exp, Number):
121 | return m + math.log(sum_exp)
122 | else:
123 | return m + torch.log(sum_exp)
124 |
125 |
126 | class ExponentialMovingAverage(object):
127 |
128 | def __init__(self, module, decay=0.999):
129 |         """Shadow parameters are initialized on the first call to .apply().
130 |         This is to take into account data-dependent initialization that occurs in the first iteration."""
131 | self.module = module
132 | self.decay = decay
133 | self.shadow_params = {}
134 | self.nparams = sum(p.numel() for p in module.parameters())
135 |
136 | def init(self):
137 | for name, param in self.module.named_parameters():
138 | self.shadow_params[name] = param.data.clone()
139 |
140 | def apply(self):
141 | if len(self.shadow_params) == 0:
142 | self.init()
143 | else:
144 | with torch.no_grad():
145 | for name, param in self.module.named_parameters():
146 | self.shadow_params[name] -= (1 - self.decay) * (self.shadow_params[name] - param.data)
147 |
148 | def set(self, other_ema):
149 | self.init()
150 | with torch.no_grad():
151 | for name, param in other_ema.shadow_params.items():
152 | self.shadow_params[name].copy_(param)
153 |
154 | def replace_with_ema(self):
155 | for name, param in self.module.named_parameters():
156 | param.data.copy_(self.shadow_params[name])
157 |
158 | def swap(self):
159 | for name, param in self.module.named_parameters():
160 | tmp = self.shadow_params[name].clone()
161 | self.shadow_params[name].copy_(param.data)
162 | param.data.copy_(tmp)
163 |
164 | def __repr__(self):
165 | return (
166 | '{}(decay={}, module={}, nparams={})'.format(
167 | self.__class__.__name__, self.decay, self.module.__class__.__name__, self.nparams
168 | )
169 | )
170 |
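    | # Typical usage: call ema.apply() after each optimizer step, and wrap evaluation
    | # in a pair of ema.swap() calls to validate with the averaged weights.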
--------------------------------------------------------------------------------
/resflows/visualize_flow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use("Agg")
4 | import matplotlib.pyplot as plt
5 | import torch
6 |
7 | LOW = -4
8 | HIGH = 4
9 |
10 |
11 | def plt_potential_func(potential, ax, npts=100, title="$p(x)$"):
12 | """
13 | Args:
14 | potential: computes U(z_k) given z_k
15 | """
16 | xside = np.linspace(LOW, HIGH, npts)
17 | yside = np.linspace(LOW, HIGH, npts)
18 | xx, yy = np.meshgrid(xside, yside)
19 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])
20 |
21 | z = torch.Tensor(z)
22 | u = potential(z).cpu().numpy()
23 | p = np.exp(-u).reshape(npts, npts)
24 |
25 | plt.pcolormesh(xx, yy, p)
26 | ax.invert_yaxis()
27 | ax.get_xaxis().set_ticks([])
28 | ax.get_yaxis().set_ticks([])
29 | ax.set_title(title)
30 |
31 |
32 | def plt_flow(prior_logdensity, transform, ax, npts=100, title="$q(x)$", device="cpu"):
33 | """
34 | Args:
35 | transform: computes z_k and log(q_k) given z_0
36 | """
37 | side = np.linspace(LOW, HIGH, npts)
38 | xx, yy = np.meshgrid(side, side)
39 | z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])
40 |
41 | z = torch.tensor(z, requires_grad=True).type(torch.float32).to(device)
42 | logqz = prior_logdensity(z)
43 | logqz = torch.sum(logqz, dim=1)[:, None]
44 | z, logqz = transform(z, logqz)
45 | logqz = torch.sum(logqz, dim=1)[:, None]
46 |
47 | xx = z[:, 0].cpu().numpy().reshape(npts, npts)
48 | yy = z[:, 1].cpu().numpy().reshape(npts, npts)
49 | qz = np.exp(logqz.cpu().numpy()).reshape(npts, npts)
50 |
51 | plt.pcolormesh(xx, yy, qz)
52 | ax.set_xlim(LOW, HIGH)
53 | ax.set_ylim(LOW, HIGH)
54 | cmap = matplotlib.cm.get_cmap(None)
55 | ax.set_facecolor(cmap(0.))
56 | ax.invert_yaxis()
57 | ax.get_xaxis().set_ticks([])
58 | ax.get_yaxis().set_ticks([])
59 | ax.set_title(title)
60 |
61 |
62 | def plt_flow_density(prior_logdensity, inverse_transform, ax, npts=100, memory=100, title="$q(x)$", device="cpu"):
63 | side = np.linspace(LOW, HIGH, npts)
64 | xx, yy = np.meshgrid(side, side)
65 | x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])
66 |
67 | x = torch.from_numpy(x).type(torch.float32).to(device)
68 | zeros = torch.zeros(x.shape[0], 1).to(x)
69 |
70 | z, delta_logp = [], []
71 | inds = torch.arange(0, x.shape[0]).to(torch.int64)
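   |     # Invert the flow in chunks of memory**2 points to bound peak (GPU) memory.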
72 | for ii in torch.split(inds, int(memory**2)):
73 | z_, delta_logp_ = inverse_transform(x[ii], zeros[ii])
74 | z.append(z_)
75 | delta_logp.append(delta_logp_)
76 | z = torch.cat(z, 0)
77 | delta_logp = torch.cat(delta_logp, 0)
78 |
79 | logpz = prior_logdensity(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z)
80 | logpx = logpz - delta_logp
81 |
82 | px = np.exp(logpx.cpu().numpy()).reshape(npts, npts)
83 |
84 | ax.imshow(px, cmap='inferno')
85 | ax.get_xaxis().set_ticks([])
86 | ax.get_yaxis().set_ticks([])
87 | ax.set_title(title)
88 |
89 |
90 | def plt_flow_samples(prior_sample, transform, ax, npts=100, memory=100, title=r"$x \sim q(x)$", device="cpu"):
91 | z = prior_sample(npts * npts, 2).type(torch.float32).to(device)
92 | zk = []
93 | inds = torch.arange(0, z.shape[0]).to(torch.int64)
94 | for ii in torch.split(inds, int(memory**2)):
95 | zk.append(transform(z[ii]))
96 | zk = torch.cat(zk, 0).cpu().numpy()
97 | ax.hist2d(zk[:, 0], zk[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts, cmap='inferno')
98 | ax.invert_yaxis()
99 | ax.get_xaxis().set_ticks([])
100 | ax.get_yaxis().set_ticks([])
101 | ax.set_title(title)
102 |
103 |
104 | def plt_samples(samples, ax, npts=100, title=r"$x \sim p(x)$"):
105 | ax.hist2d(samples[:, 0], samples[:, 1], range=[[LOW, HIGH], [LOW, HIGH]], bins=npts, cmap='inferno')
106 | ax.invert_yaxis()
107 | ax.get_xaxis().set_ticks([])
108 | ax.get_yaxis().set_ticks([])
109 | ax.set_title(title)
110 |
111 |
112 | def visualize_transform(
113 | potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100,
114 | memory=100, device="cpu"
115 | ):
116 | """Produces visualization for the model density and samples from the model."""
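    |     # Panel 1: data samples (or the target potential); panel 2: model density; panel 3: model samples.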
117 | plt.clf()
118 | ax = plt.subplot(1, 3, 1, aspect="equal")
119 | if samples:
120 | plt_samples(potential_or_samples, ax, npts=npts)
121 | else:
122 | plt_potential_func(potential_or_samples, ax, npts=npts)
123 |
124 | ax = plt.subplot(1, 3, 2, aspect="equal")
125 | if inverse_transform is None:
126 | plt_flow(prior_density, transform, ax, npts=npts, device=device)
127 | else:
128 | plt_flow_density(prior_density, inverse_transform, ax, npts=npts, memory=memory, device=device)
129 |
130 | ax = plt.subplot(1, 3, 3, aspect="equal")
131 | if transform is not None:
132 | plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device)
133 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | 
3 | setup(
4 |     name="resflows",
5 |     version="0.0.1",
6 |     packages=["resflows", "resflows.layers", "resflows.layers.base"],
7 | )
--------------------------------------------------------------------------------
/train_img.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import math
4 | import os
5 | import os.path
6 | import numpy as np
7 | from tqdm import tqdm
8 | import gc
9 |
10 | import torch
11 | import torchvision.transforms as transforms
12 | from torchvision.utils import save_image
13 | import torchvision.datasets as vdsets
14 |
15 | from resflows.resflow import ACT_FNS, ResidualFlow
16 | import resflows.datasets as datasets
17 | import resflows.optimizers as optim
18 | import resflows.utils as utils
19 | import resflows.layers as layers
20 | import resflows.layers.base as base_layers
21 | from resflows.lr_scheduler import CosineAnnealingWarmRestarts
22 |
23 | # Arguments
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument(
26 | '--data', type=str, default='cifar10', choices=[
27 | 'mnist',
28 | 'cifar10',
29 | 'svhn',
30 | 'celebahq',
31 | 'celeba_5bit',
32 | 'imagenet32',
33 | 'imagenet64',
34 | ]
35 | )
36 | parser.add_argument('--dataroot', type=str, default='data')
37 | parser.add_argument('--imagesize', type=int, default=32)
38 | parser.add_argument('--nbits', type=int, default=8) # Only used for celebahq.
39 |
40 | parser.add_argument('--block', type=str, choices=['resblock', 'coupling'], default='resblock')
41 |
42 | parser.add_argument('--coeff', type=float, default=0.98)
43 | parser.add_argument('--vnorms', type=str, default='2222')
44 | parser.add_argument('--n-lipschitz-iters', type=int, default=None)
45 | parser.add_argument('--sn-tol', type=float, default=1e-3)
46 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False)
47 |
48 | parser.add_argument('--n-power-series', type=int, default=None)
49 | parser.add_argument('--factor-out', type=eval, choices=[True, False], default=False)
50 | parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='poisson')
51 | parser.add_argument('--n-samples', type=int, default=1)
52 | parser.add_argument('--n-exact-terms', type=int, default=2)
53 | parser.add_argument('--var-reduc-lr', type=float, default=0)
54 | parser.add_argument('--neumann-grad', type=eval, choices=[True, False], default=True)
55 | parser.add_argument('--mem-eff', type=eval, choices=[True, False], default=True)
56 |
57 | parser.add_argument('--act', type=str, choices=ACT_FNS.keys(), default='swish')
58 | parser.add_argument('--idim', type=int, default=512)
59 | parser.add_argument('--nblocks', type=str, default='16-16-16')
60 | parser.add_argument('--squeeze-first', type=eval, default=False, choices=[True, False])
61 | parser.add_argument('--actnorm', type=eval, default=True, choices=[True, False])
62 | parser.add_argument('--fc-actnorm', type=eval, default=False, choices=[True, False])
63 | parser.add_argument('--batchnorm', type=eval, default=False, choices=[True, False])
64 | parser.add_argument('--dropout', type=float, default=0.)
65 | parser.add_argument('--fc', type=eval, default=False, choices=[True, False])
66 | parser.add_argument('--kernels', type=str, default='3-1-3')
67 | parser.add_argument('--logit-transform', type=eval, choices=[True, False], default=True)
68 | parser.add_argument('--add-noise', type=eval, choices=[True, False], default=True)
69 | parser.add_argument('--quadratic', type=eval, choices=[True, False], default=False)
70 | parser.add_argument('--fc-end', type=eval, choices=[True, False], default=True)
71 | parser.add_argument('--fc-idim', type=int, default=128)
72 | parser.add_argument('--preact', type=eval, choices=[True, False], default=True)
73 | parser.add_argument('--padding', type=int, default=0)
74 | parser.add_argument('--first-resblock', type=eval, choices=[True, False], default=True)
75 | parser.add_argument('--cdim', type=int, default=256)
76 |
77 | parser.add_argument('--optimizer', type=str, choices=['adam', 'adamax', 'rmsprop', 'sgd'], default='adam')
78 | parser.add_argument('--scheduler', type=eval, choices=[True, False], default=False)
79 | parser.add_argument('--nepochs', help='Number of epochs for training', type=int, default=1000)
80 | parser.add_argument('--batchsize', help='Minibatch size', type=int, default=64)
81 | parser.add_argument('--lr', help='Learning rate', type=float, default=1e-3)
82 | parser.add_argument('--wd', help='Weight decay', type=float, default=0)
83 | parser.add_argument('--warmup-iters', type=int, default=1000)
84 | parser.add_argument('--annealing-iters', type=int, default=0)
85 | parser.add_argument('--save', help='directory to save results', type=str, default='experiment1')
86 | parser.add_argument('--val-batchsize', help='Minibatch size for validation', type=int, default=200)
87 | parser.add_argument('--seed', type=int, default=None)
88 | parser.add_argument('--ema-val', type=eval, choices=[True, False], default=True)
89 | parser.add_argument('--update-freq', type=int, default=1)
90 |
91 | parser.add_argument('--task', type=str, choices=['density', 'classification', 'hybrid'], default='density')
92 | parser.add_argument('--scale-dim', type=eval, choices=[True, False], default=False)
93 | parser.add_argument('--rcrop-pad-mode', type=str, choices=['constant', 'reflect'], default='reflect')
94 | parser.add_argument('--padding-dist', type=str, choices=['uniform', 'gaussian'], default='uniform')
95 |
96 | parser.add_argument('--resume', type=str, default=None)
97 | parser.add_argument('--begin-epoch', type=int, default=0)
98 |
99 | parser.add_argument('--nworkers', type=int, default=4)
100 | parser.add_argument('--print-freq', help='Print progress every N iterations', type=int, default=20)
101 | parser.add_argument('--vis-freq', help='Visualize progress every N iterations', type=int, default=500)
102 | args = parser.parse_args()
103 |
104 | # Random seed
105 | if args.seed is None:
106 | args.seed = np.random.randint(100000)
107 |
108 | # logger
109 | utils.makedirs(args.save)
110 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
111 | logger.info(args)
112 |
113 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
114 |
115 | if device.type == 'cuda':
116 | logger.info('Found {} CUDA devices.'.format(torch.cuda.device_count()))
117 | for i in range(torch.cuda.device_count()):
118 | props = torch.cuda.get_device_properties(i)
119 | logger.info('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3)))
120 | else:
121 | logger.info('WARNING: Using device {}'.format(device))
122 |
123 | np.random.seed(args.seed)
124 | torch.manual_seed(args.seed)
125 | if device.type == 'cuda':
126 | torch.cuda.manual_seed(args.seed)
127 |
128 |
129 | def geometric_logprob(ns, p):
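    |     # log P(N = ns) for a Geometric(p) supported on {1, 2, ...}; the 1e-10 guards against log(0).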
130 | return torch.log(1 - p + 1e-10) * (ns - 1) + torch.log(p + 1e-10)
131 |
132 |
133 | def standard_normal_sample(size):
134 | return torch.randn(size)
135 |
136 |
137 | def standard_normal_logprob(z):
138 | logZ = -0.5 * math.log(2 * math.pi)
139 | return logZ - z.pow(2) / 2
140 |
141 |
142 | def normal_logprob(z, mean, log_std):
143 | mean = mean + torch.tensor(0.)
144 | log_std = log_std + torch.tensor(0.)
145 | c = torch.tensor([math.log(2 * math.pi)]).to(z)
146 | inv_sigma = torch.exp(-log_std)
147 | tmp = (z - mean) * inv_sigma
148 | return -0.5 * (tmp * tmp + 2 * log_std + c)
149 |
150 |
151 | def count_parameters(model):
152 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
153 |
154 |
155 | def reduce_bits(x):
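    |     # Requantize 8-bit pixels in [0, 1] down to args.nbits bits (used for low-bit CelebA-HQ).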
156 | if args.nbits < 8:
157 | x = x * 255
158 | x = torch.floor(x / 2**(8 - args.nbits))
159 | x = x / 2**args.nbits
160 | return x
161 |
162 |
163 | def add_noise(x, nvals=256):
164 | """
165 | [0, 1] -> [0, nvals] -> add noise -> [0, 1]
166 | """
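    |     # Uniform dequantization: adding U(0, 1) noise to the discrete pixel values makes
    |     # the continuous log-density (and hence bits/dim) well-defined.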
167 | if args.add_noise:
168 | noise = x.new().resize_as_(x).uniform_()
169 | x = x * (nvals - 1) + noise
170 | x = x / nvals
171 | return x
172 |
173 |
174 | def update_lr(optimizer, itr):
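    |     # Linear warm-up: ramp the learning rate from ~0 to args.lr over the first warmup_iters steps.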
175 | iter_frac = min(float(itr + 1) / max(args.warmup_iters, 1), 1.0)
176 | lr = args.lr * iter_frac
177 | for param_group in optimizer.param_groups:
178 | param_group["lr"] = lr
179 |
180 |
181 | def add_padding(x, nvals=256):
182 | # Theoretically, padding should've been added before the add_noise preprocessing.
183 | # nvals takes into account the preprocessing before padding is added.
184 | if args.padding > 0:
185 | if args.padding_dist == 'uniform':
186 | u = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).uniform_()
187 | logpu = torch.zeros_like(u).sum([1, 2, 3]).view(-1, 1)
188 | return torch.cat([x, u / nvals], dim=1), logpu
189 | elif args.padding_dist == 'gaussian':
190 | u = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).normal_(nvals / 2, nvals / 8)
191 | logpu = normal_logprob(u, nvals / 2, math.log(nvals / 8)).sum([1, 2, 3]).view(-1, 1)
192 | return torch.cat([x, u / nvals], dim=1), logpu
193 | else:
194 | raise ValueError()
195 | else:
196 | return x, torch.zeros(x.shape[0], 1).to(x)
197 |
198 |
199 | def remove_padding(x):
200 | if args.padding > 0:
201 | return x[:, :im_dim, :, :]
202 | else:
203 | return x
204 |
205 |
206 | logger.info('Loading dataset {}'.format(args.data))
207 | # Dataset and hyperparameters
208 | if args.data == 'cifar10':
209 | im_dim = 3
210 | n_classes = 10
211 | if args.task in ['classification', 'hybrid']:
212 |
213 | # Classification-specific preprocessing.
214 | transform_train = transforms.Compose([
215 | transforms.Resize(args.imagesize),
216 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode),
217 | transforms.RandomHorizontalFlip(),
218 | transforms.ToTensor(),
219 | add_noise,
220 | ])
221 |
222 | transform_test = transforms.Compose([
223 | transforms.Resize(args.imagesize),
224 | transforms.ToTensor(),
225 | add_noise,
226 | ])
227 |
228 | # Remove the logit transform.
229 | init_layer = layers.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
230 | else:
231 | transform_train = transforms.Compose([
232 | transforms.Resize(args.imagesize),
233 | transforms.RandomHorizontalFlip(),
234 | transforms.ToTensor(),
235 | add_noise,
236 | ])
237 | transform_test = transforms.Compose([
238 | transforms.Resize(args.imagesize),
239 | transforms.ToTensor(),
240 | add_noise,
241 | ])
242 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None
243 | train_loader = torch.utils.data.DataLoader(
244 | datasets.CIFAR10(args.dataroot, train=True, transform=transform_train),
245 | batch_size=args.batchsize,
246 | shuffle=True,
247 | num_workers=args.nworkers,
248 | )
249 | test_loader = torch.utils.data.DataLoader(
250 | datasets.CIFAR10(args.dataroot, train=False, transform=transform_test),
251 | batch_size=args.val_batchsize,
252 | shuffle=False,
253 | num_workers=args.nworkers,
254 | )
255 | elif args.data == 'mnist':
256 | im_dim = 1
257 | init_layer = layers.LogitTransform(1e-6) if args.logit_transform else None
258 | n_classes = 10
259 | train_loader = torch.utils.data.DataLoader(
260 | datasets.MNIST(
261 | args.dataroot, train=True, transform=transforms.Compose([
262 | transforms.Resize(args.imagesize),
263 | transforms.ToTensor(),
264 | add_noise,
265 | ])
266 | ),
267 | batch_size=args.batchsize,
268 | shuffle=True,
269 | num_workers=args.nworkers,
270 | )
271 | test_loader = torch.utils.data.DataLoader(
272 | datasets.MNIST(
273 | args.dataroot, train=False, transform=transforms.Compose([
274 | transforms.Resize(args.imagesize),
275 | transforms.ToTensor(),
276 | add_noise,
277 | ])
278 | ),
279 | batch_size=args.val_batchsize,
280 | shuffle=False,
281 | num_workers=args.nworkers,
282 | )
283 | elif args.data == 'svhn':
284 | im_dim = 3
285 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None
286 | n_classes = 10
287 | train_loader = torch.utils.data.DataLoader(
288 | vdsets.SVHN(
289 | args.dataroot, split='train', download=True, transform=transforms.Compose([
290 | transforms.Resize(args.imagesize),
291 | transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode),
292 | transforms.ToTensor(),
293 | add_noise,
294 | ])
295 | ),
296 | batch_size=args.batchsize,
297 | shuffle=True,
298 | num_workers=args.nworkers,
299 | )
300 | test_loader = torch.utils.data.DataLoader(
301 | vdsets.SVHN(
302 | args.dataroot, split='test', download=True, transform=transforms.Compose([
303 | transforms.Resize(args.imagesize),
304 | transforms.ToTensor(),
305 | add_noise,
306 | ])
307 | ),
308 | batch_size=args.val_batchsize,
309 | shuffle=False,
310 | num_workers=args.nworkers,
311 | )
312 | elif args.data == 'celebahq':
313 | im_dim = 3
314 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None
315 | if args.imagesize != 256:
316 | logger.info('Changing image size to 256.')
317 | args.imagesize = 256
318 | train_loader = torch.utils.data.DataLoader(
319 | datasets.CelebAHQ(
320 | train=True, transform=transforms.Compose([
321 | transforms.ToPILImage(),
322 | transforms.RandomHorizontalFlip(),
323 | transforms.ToTensor(),
324 | reduce_bits,
325 | lambda x: add_noise(x, nvals=2**args.nbits),
326 | ])
327 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
328 | )
329 | test_loader = torch.utils.data.DataLoader(
330 | datasets.CelebAHQ(
331 | train=False, transform=transforms.Compose([
332 | reduce_bits,
333 | lambda x: add_noise(x, nvals=2**args.nbits),
334 | ])
335 | ), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
336 | )
337 | elif args.data == 'celeba_5bit':
338 | im_dim = 3
339 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None
340 | if args.imagesize != 64:
341 | logger.info('Changing image size to 64.')
342 | args.imagesize = 64
343 | train_loader = torch.utils.data.DataLoader(
344 | datasets.CelebA5bit(
345 | train=True, transform=transforms.Compose([
346 | transforms.ToPILImage(),
347 | transforms.RandomHorizontalFlip(),
348 | transforms.ToTensor(),
349 | lambda x: add_noise(x, nvals=32),
350 | ])
351 | ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
352 | )
353 | test_loader = torch.utils.data.DataLoader(
354 | datasets.CelebA5bit(train=False, transform=transforms.Compose([
355 | lambda x: add_noise(x, nvals=32),
356 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
357 | )
358 | elif args.data == 'imagenet32':
359 | im_dim = 3
360 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None
361 | if args.imagesize != 32:
362 | logger.info('Changing image size to 32.')
363 | args.imagesize = 32
364 | train_loader = torch.utils.data.DataLoader(
365 | datasets.Imagenet32(train=True, transform=transforms.Compose([
366 | add_noise,
367 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
368 | )
369 | test_loader = torch.utils.data.DataLoader(
370 | datasets.Imagenet32(train=False, transform=transforms.Compose([
371 | add_noise,
372 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
373 | )
374 | elif args.data == 'imagenet64':
375 | im_dim = 3
376 | init_layer = layers.LogitTransform(0.05) if args.logit_transform else None
377 | if args.imagesize != 64:
378 | logger.info('Changing image size to 64.')
379 | args.imagesize = 64
380 | train_loader = torch.utils.data.DataLoader(
381 | datasets.Imagenet64(train=True, transform=transforms.Compose([
382 | add_noise,
383 | ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
384 | )
385 | test_loader = torch.utils.data.DataLoader(
386 | datasets.Imagenet64(train=False, transform=transforms.Compose([
387 | add_noise,
388 | ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
389 | )
390 |
391 | if args.task in ['classification', 'hybrid']:
392 | try:
393 | n_classes
394 | except NameError:
395 | raise ValueError('Cannot perform classification with {}'.format(args.data))
396 | else:
397 | n_classes = 1
398 |
399 | logger.info('Dataset loaded.')
400 | logger.info('Creating model.')
401 |
402 | input_size = (args.batchsize, im_dim + args.padding, args.imagesize, args.imagesize)
403 | dataset_size = len(train_loader.dataset)
404 |
405 | if args.squeeze_first:
406 | input_size = (input_size[0], input_size[1] * 4, input_size[2] // 2, input_size[3] // 2)
407 | squeeze_layer = layers.SqueezeLayer(2)
408 |
409 | # Model
410 | model = ResidualFlow(
411 | input_size,
412 | n_blocks=list(map(int, args.nblocks.split('-'))),
413 | intermediate_dim=args.idim,
414 | factor_out=args.factor_out,
415 | quadratic=args.quadratic,
416 | init_layer=init_layer,
417 | actnorm=args.actnorm,
418 | fc_actnorm=args.fc_actnorm,
419 | batchnorm=args.batchnorm,
420 | dropout=args.dropout,
421 | fc=args.fc,
422 | coeff=args.coeff,
423 | vnorms=args.vnorms,
424 | n_lipschitz_iters=args.n_lipschitz_iters,
425 | sn_atol=args.sn_tol,
426 | sn_rtol=args.sn_tol,
427 | n_power_series=args.n_power_series,
428 | n_dist=args.n_dist,
429 | n_samples=args.n_samples,
430 | kernels=args.kernels,
431 | activation_fn=args.act,
432 | fc_end=args.fc_end,
433 | fc_idim=args.fc_idim,
434 | n_exact_terms=args.n_exact_terms,
435 | preact=args.preact,
436 | neumann_grad=args.neumann_grad,
437 | grad_in_forward=args.mem_eff,
438 | first_resblock=args.first_resblock,
439 | learn_p=args.learn_p,
440 | classification=args.task in ['classification', 'hybrid'],
441 | classification_hdim=args.cdim,
442 | n_classes=n_classes,
443 | block_type=args.block,
444 | )
445 |
446 | model.to(device)
447 | ema = utils.ExponentialMovingAverage(model)
448 |
449 |
450 | def parallelize(model):
451 | return torch.nn.DataParallel(model)
452 |
453 |
454 | logger.info(model)
455 | logger.info('EMA: {}'.format(ema))
456 |
457 |
458 | # Optimization
459 | def tensor_in(t, a):
460 | for a_ in a:
461 | if t is a_:
462 | return True
463 | return False
464 |
465 |
466 | scheduler = None
467 |
468 | if args.optimizer == 'adam':
469 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.wd)
470 | if args.scheduler: scheduler = CosineAnnealingWarmRestarts(optimizer, 20, T_mult=2, last_epoch=args.begin_epoch - 1)
471 | elif args.optimizer == 'adamax':
472 | optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.wd)
473 | elif args.optimizer == 'rmsprop':
474 | optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd)
475 | elif args.optimizer == 'sgd':
476 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)
477 | if args.scheduler:
478 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
479 | optimizer, milestones=[60, 120, 160], gamma=0.2, last_epoch=args.begin_epoch - 1
480 | )
481 | else:
482 | raise ValueError('Unknown optimizer {}'.format(args.optimizer))
483 |
484 | best_test_bpd = math.inf
485 | if args.resume is not None:
486 | logger.info('Resuming model from {}'.format(args.resume))
487 | with torch.no_grad():
488 | x = torch.rand(1, *input_size[1:]).to(device)
489 | model(x)
490 | checkpt = torch.load(args.resume)
491 | sd = {k: v for k, v in checkpt['state_dict'].items() if 'last_n_samples' not in k}
492 | state = model.state_dict()
493 | state.update(sd)
494 | model.load_state_dict(state, strict=True)
495 | ema.set(checkpt['ema'])
496 | if 'optimizer_state_dict' in checkpt:
497 | optimizer.load_state_dict(checkpt['optimizer_state_dict'])
498 |             # Manually move optimizer state tensors to the active device.
499 |             for opt_state in optimizer.state.values():
500 |                 for k, v in opt_state.items():
501 |                     if torch.is_tensor(v):
502 |                         opt_state[k] = v.to(device)
503 |         del checkpt
504 |         del state
505 |
506 | logger.info(optimizer)
507 |
508 | fixed_z = standard_normal_sample([min(32, args.batchsize),
509 | (im_dim + args.padding) * args.imagesize * args.imagesize]).to(device)
510 |
511 | criterion = torch.nn.CrossEntropyLoss()
512 |
513 |
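# Objective computed below: with z and delta_logp from the flow,
#   log p(x) = log p(z) - beta * delta_logp - D_pad * log(nvals) - logpu,
# where D_pad = imagesize^2 * (im_dim + padding) normalizes the uniformly
# dequantized pixels and logpu corrects for the padding distribution.
# Bits/dim is then -log p(x) / (imagesize^2 * im_dim * ln 2).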
514 | def compute_loss(x, model, beta=1.0):
515 | bits_per_dim, logits_tensor = torch.zeros(1).to(x), torch.zeros(n_classes).to(x)
516 | logpz, delta_logp = torch.zeros(1).to(x), torch.zeros(1).to(x)
517 |
518 | if args.data == 'celeba_5bit':
519 | nvals = 32
520 | elif args.data == 'celebahq':
521 | nvals = 2**args.nbits
522 | else:
523 | nvals = 256
524 |
525 | x, logpu = add_padding(x, nvals)
526 |
527 | if args.squeeze_first:
528 | x = squeeze_layer(x)
529 |
530 | if args.task == 'hybrid':
531 | z_logp, logits_tensor = model(x.view(-1, *input_size[1:]), 0, classify=True)
532 | z, delta_logp = z_logp
533 | elif args.task == 'density':
534 | z, delta_logp = model(x.view(-1, *input_size[1:]), 0)
535 | elif args.task == 'classification':
536 | z, logits_tensor = model(x.view(-1, *input_size[1:]), classify=True)
537 |
538 | if args.task in ['density', 'hybrid']:
539 | # log p(z)
540 | logpz = standard_normal_logprob(z).view(z.size(0), -1).sum(1, keepdim=True)
541 |
542 | # log p(x)
543 | logpx = logpz - beta * delta_logp - np.log(nvals) * (
544 | args.imagesize * args.imagesize * (im_dim + args.padding)
545 | ) - logpu
546 | bits_per_dim = -torch.mean(logpx) / (args.imagesize * args.imagesize * im_dim) / np.log(2)
547 |
548 | logpz = torch.mean(logpz).detach()
549 | delta_logp = torch.mean(-delta_logp).detach()
550 |
551 | return bits_per_dim, logits_tensor, logpz, delta_logp
552 |
553 |
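# Each iResBlock caches the first and second moments of its most recent
# stochastic log-determinant estimate; summing these over blocks gives the
# EstMoment diagnostic logged during training.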
554 | def estimator_moments(model):
555 | avg_first_moment = 0.
556 | avg_second_moment = 0.
557 | for m in model.modules():
558 | if isinstance(m, layers.iResBlock):
559 | avg_first_moment += m.last_firmom.item()
560 | avg_second_moment += m.last_secmom.item()
561 | return avg_first_moment, avg_second_moment
562 |
563 |
564 | def compute_p_grads(model):
565 | scales = 0.
566 | nlayers = 0
567 | for m in model.modules():
568 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
569 | scales = scales + m.compute_one_iter()
570 | nlayers += 1
571 | scales.mul(1 / nlayers).backward()
572 | for m in model.modules():
573 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
574 | if m.domain.grad is not None and torch.isnan(m.domain.grad):
575 | m.domain.grad = None
576 |
577 |
578 | batch_time = utils.RunningAverageMeter(0.97)
579 | bpd_meter = utils.RunningAverageMeter(0.97)
580 | logpz_meter = utils.RunningAverageMeter(0.97)
581 | deltalogp_meter = utils.RunningAverageMeter(0.97)
582 | firmom_meter = utils.RunningAverageMeter(0.97)
583 | secmom_meter = utils.RunningAverageMeter(0.97)
584 | gnorm_meter = utils.RunningAverageMeter(0.97)
585 | ce_meter = utils.RunningAverageMeter(0.97)
586 |
587 |
588 | def train(epoch, model):
589 |
590 | model = parallelize(model)
591 | model.train()
592 |
593 | total = 0
594 | correct = 0
595 |
596 | end = time.time()
597 |
598 | for i, (x, y) in enumerate(train_loader):
599 |
600 | global_itr = epoch * len(train_loader) + i
601 | update_lr(optimizer, global_itr)
602 |
603 | # Training procedure:
604 | # for each sample x:
605 | # compute z = f(x)
606 | # maximize log p(x) = log p(z) - log |det df/dx|
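        # beta anneals the change-of-variables term from 0 to 1 over the
        # first annealing_iters iterations (when enabled), which can make
        # early optimization better behaved.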
607 |
608 | x = x.to(device)
609 |
610 |         beta = min(1, global_itr / args.annealing_iters) if args.annealing_iters > 0 else 1.
611 | bpd, logits, logpz, neg_delta_logp = compute_loss(x, model, beta=beta)
612 |
613 | if args.task in ['density', 'hybrid']:
614 | firmom, secmom = estimator_moments(model)
615 |
616 | bpd_meter.update(bpd.item())
617 | logpz_meter.update(logpz.item())
618 | deltalogp_meter.update(neg_delta_logp.item())
619 | firmom_meter.update(firmom)
620 | secmom_meter.update(secmom)
621 |
622 | if args.task in ['classification', 'hybrid']:
623 | y = y.to(device)
624 | crossent = criterion(logits, y)
625 | ce_meter.update(crossent.item())
626 |
627 | # Compute accuracy.
628 | _, predicted = logits.max(1)
629 | total += y.size(0)
630 | correct += predicted.eq(y).sum().item()
631 |
632 | # compute gradient and do SGD step
633 | if args.task == 'density':
634 | loss = bpd
635 | elif args.task == 'classification':
636 | loss = crossent
637 | else:
638 | if not args.scale_dim: bpd = bpd * (args.imagesize * args.imagesize * im_dim)
639 | loss = bpd + crossent / np.log(2) # Change cross entropy from nats to bits.
640 | loss.backward()
641 |
642 | if global_itr % args.update_freq == args.update_freq - 1:
643 |
644 | if args.update_freq > 1:
645 | with torch.no_grad():
646 | for p in model.parameters():
647 | if p.grad is not None:
648 | p.grad /= args.update_freq
649 |
650 |             grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
651 | if args.learn_p: compute_p_grads(model)
652 |
653 | optimizer.step()
654 | optimizer.zero_grad()
655 | update_lipschitz(model)
656 | ema.apply()
657 |
658 | gnorm_meter.update(grad_norm)
659 |
660 | # measure elapsed time
661 | batch_time.update(time.time() - end)
662 | end = time.time()
663 |
664 | if i % args.print_freq == 0:
665 | s = (
666 | 'Epoch: [{0}][{1}/{2}] | Time {batch_time.val:.3f} | '
667 | 'GradNorm {gnorm_meter.avg:.2f}'.format(
668 | epoch, i, len(train_loader), batch_time=batch_time, gnorm_meter=gnorm_meter
669 | )
670 | )
671 |
672 | if args.task in ['density', 'hybrid']:
673 | s += (
674 | ' | Bits/dim {bpd_meter.val:.4f}({bpd_meter.avg:.4f}) | '
675 | 'Logpz {logpz_meter.avg:.0f} | '
676 | '-DeltaLogp {deltalogp_meter.avg:.0f} | '
677 | 'EstMoment ({firmom_meter.avg:.0f},{secmom_meter.avg:.0f})'.format(
678 | bpd_meter=bpd_meter, logpz_meter=logpz_meter, deltalogp_meter=deltalogp_meter,
679 | firmom_meter=firmom_meter, secmom_meter=secmom_meter
680 | )
681 | )
682 |
683 | if args.task in ['classification', 'hybrid']:
684 | s += ' | CE {ce_meter.avg:.4f} | Acc {0:.4f}'.format(100 * correct / total, ce_meter=ce_meter)
685 |
686 | logger.info(s)
687 | if i % args.vis_freq == 0:
688 | visualize(epoch, model, i, x)
689 |
690 | del x
691 | torch.cuda.empty_cache()
692 | gc.collect()
693 |
694 |
695 | def validate(epoch, model, ema=None):
696 | """
697 | Evaluates the cross entropy between p_data and p_model.
698 | """
699 | bpd_meter = utils.AverageMeter()
700 | ce_meter = utils.AverageMeter()
701 |
702 | if ema is not None:
703 | ema.swap()
704 |
705 | update_lipschitz(model)
706 |
707 | model = parallelize(model)
708 | model.eval()
709 |
710 | correct = 0
711 | total = 0
712 |
713 | start = time.time()
714 | with torch.no_grad():
715 | for i, (x, y) in enumerate(tqdm(test_loader)):
716 | x = x.to(device)
717 | bpd, logits, _, _ = compute_loss(x, model)
718 | bpd_meter.update(bpd.item(), x.size(0))
719 |
720 | if args.task in ['classification', 'hybrid']:
721 | y = y.to(device)
722 | loss = criterion(logits, y)
723 | ce_meter.update(loss.item(), x.size(0))
724 | _, predicted = logits.max(1)
725 | total += y.size(0)
726 | correct += predicted.eq(y).sum().item()
727 | val_time = time.time() - start
728 |
729 | if ema is not None:
730 | ema.swap()
731 | s = 'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {bpd_meter.avg:.4f}'.format(epoch, val_time, bpd_meter=bpd_meter)
732 | if args.task in ['classification', 'hybrid']:
733 | s += ' | CE {:.4f} | Acc {:.2f}'.format(ce_meter.avg, 100 * correct / total)
734 | logger.info(s)
735 | return bpd_meter.avg
736 |
737 |
738 | def visualize(epoch, model, itr, real_imgs):
739 | model.eval()
740 | utils.makedirs(os.path.join(args.save, 'imgs'))
741 | real_imgs = real_imgs[:32]
742 | _real_imgs = real_imgs
743 |
744 | if args.data == 'celeba_5bit':
745 | nvals = 32
746 | elif args.data == 'celebahq':
747 | nvals = 2**args.nbits
748 | else:
749 | nvals = 256
750 |
751 | with torch.no_grad():
752 | # reconstructed real images
753 | real_imgs, _ = add_padding(real_imgs, nvals)
754 | if args.squeeze_first: real_imgs = squeeze_layer(real_imgs)
755 | recon_imgs = model(model(real_imgs.view(-1, *input_size[1:])), inverse=True).view(-1, *input_size[1:])
756 | if args.squeeze_first: recon_imgs = squeeze_layer.inverse(recon_imgs)
757 | recon_imgs = remove_padding(recon_imgs)
758 |
759 | # random samples
760 | fake_imgs = model(fixed_z, inverse=True).view(-1, *input_size[1:])
761 | if args.squeeze_first: fake_imgs = squeeze_layer.inverse(fake_imgs)
762 | fake_imgs = remove_padding(fake_imgs)
763 |
764 | fake_imgs = fake_imgs.view(-1, im_dim, args.imagesize, args.imagesize)
765 | recon_imgs = recon_imgs.view(-1, im_dim, args.imagesize, args.imagesize)
766 | imgs = torch.cat([_real_imgs, fake_imgs, recon_imgs], 0)
767 |
768 | filename = os.path.join(args.save, 'imgs', 'e{:03d}_i{:06d}.png'.format(epoch, itr))
769 | save_image(imgs.cpu().float(), filename, nrow=16, padding=2)
770 | model.train()
771 |
772 |
773 | def get_lipschitz_constants(model):
774 | lipschitz_constants = []
775 | for m in model.modules():
776 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear):
777 | lipschitz_constants.append(m.scale)
778 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
779 | lipschitz_constants.append(m.scale)
780 | if isinstance(m, base_layers.LopConv2d) or isinstance(m, base_layers.LopLinear):
781 | lipschitz_constants.append(m.scale)
782 | return lipschitz_constants
783 |
784 |
785 | def update_lipschitz(model):
786 | with torch.no_grad():
787 | for m in model.modules():
788 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear):
789 | m.compute_weight(update=True)
790 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
791 | m.compute_weight(update=True)
792 |
793 |
794 | def get_ords(model):
795 | ords = []
796 | for m in model.modules():
797 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
798 | domain, codomain = m.compute_domain_codomain()
799 | if torch.is_tensor(domain):
800 | domain = domain.item()
801 | if torch.is_tensor(codomain):
802 | codomain = codomain.item()
803 | ords.append(domain)
804 | ords.append(codomain)
805 | return ords
806 |
807 |
808 | def pretty_repr(a):
809 | return '[[' + ','.join(list(map(lambda i: f'{i:.2f}', a))) + ']]'
810 |
811 |
812 | def main():
813 | global best_test_bpd
814 |
815 | last_checkpoints = []
816 | lipschitz_constants = []
817 | ords = []
818 |
819 | # if args.resume:
820 | # validate(args.begin_epoch - 1, model, ema)
821 | for epoch in range(args.begin_epoch, args.nepochs):
822 |
823 | logger.info('Current LR {}'.format(optimizer.param_groups[0]['lr']))
824 |
825 | train(epoch, model)
826 | lipschitz_constants.append(get_lipschitz_constants(model))
827 | ords.append(get_ords(model))
828 | logger.info('Lipsh: {}'.format(pretty_repr(lipschitz_constants[-1])))
829 | logger.info('Order: {}'.format(pretty_repr(ords[-1])))
830 |
831 | if args.ema_val:
832 | test_bpd = validate(epoch, model, ema)
833 | else:
834 | test_bpd = validate(epoch, model)
835 |
836 | if args.scheduler and scheduler is not None:
837 | scheduler.step()
838 |
839 | if test_bpd < best_test_bpd:
840 | best_test_bpd = test_bpd
841 | utils.save_checkpoint({
842 | 'state_dict': model.state_dict(),
843 | 'optimizer_state_dict': optimizer.state_dict(),
844 | 'args': args,
845 | 'ema': ema,
846 | 'test_bpd': test_bpd,
847 | }, os.path.join(args.save, 'models'), epoch, last_checkpoints, num_checkpoints=5)
848 |
849 | torch.save({
850 | 'state_dict': model.state_dict(),
851 | 'optimizer_state_dict': optimizer.state_dict(),
852 | 'args': args,
853 | 'ema': ema,
854 | 'test_bpd': test_bpd,
855 | }, os.path.join(args.save, 'models', 'most_recent.pth'))
856 |
857 |
858 | if __name__ == '__main__':
859 | main()
860 |
--------------------------------------------------------------------------------
/train_toy.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 |
5 | import argparse
6 | import os
7 | import time
8 | import math
9 | import numpy as np
10 |
11 | import torch
12 |
13 | import resflows.optimizers as optim
14 | import resflows.layers.base as base_layers
15 | import resflows.layers as layers
16 | import resflows.toy_data as toy_data
17 | import resflows.utils as utils
18 | from resflows.visualize_flow import visualize_transform
19 |
20 | ACTIVATION_FNS = {
21 | 'relu': torch.nn.ReLU,
22 | 'tanh': torch.nn.Tanh,
23 | 'elu': torch.nn.ELU,
24 | 'selu': torch.nn.SELU,
25 | 'fullsort': base_layers.FullSort,
26 | 'maxmin': base_layers.MaxMin,
27 | 'swish': base_layers.Swish,
28 | 'lcube': base_layers.LipschitzCube,
29 | }
30 |
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument(
33 | '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'],
34 | type=str, default='pinwheel'
35 | )
36 | parser.add_argument('--arch', choices=['iresnet', 'realnvp'], default='iresnet')
37 | parser.add_argument('--coeff', type=float, default=0.9)
38 | parser.add_argument('--vnorms', type=str, default='222222')
39 | parser.add_argument('--n-lipschitz-iters', type=int, default=5)
40 | parser.add_argument('--atol', type=float, default=None)
41 | parser.add_argument('--rtol', type=float, default=None)
42 | parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False)
43 | parser.add_argument('--mixed', type=eval, choices=[True, False], default=True)
44 |
45 | parser.add_argument('--dims', type=str, default='128-128-128-128')
46 | parser.add_argument('--act', type=str, choices=ACTIVATION_FNS.keys(), default='swish')
47 | parser.add_argument('--nblocks', type=int, default=100)
48 | parser.add_argument('--brute-force', type=eval, choices=[True, False], default=False)
49 | parser.add_argument('--actnorm', type=eval, choices=[True, False], default=False)
50 | parser.add_argument('--batchnorm', type=eval, choices=[True, False], default=False)
51 | parser.add_argument('--exact-trace', type=eval, choices=[True, False], default=False)
52 | parser.add_argument('--n-power-series', type=int, default=None)
53 | parser.add_argument('--n-samples', type=int, default=1)
54 | parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric')
55 |
56 | parser.add_argument('--niters', type=int, default=50000)
57 | parser.add_argument('--batch_size', type=int, default=500)
58 | parser.add_argument('--test_batch_size', type=int, default=10000)
59 | parser.add_argument('--lr', type=float, default=1e-3)
60 | parser.add_argument('--weight-decay', type=float, default=1e-5)
61 | parser.add_argument('--annealing-iters', type=int, default=0)
62 |
63 | parser.add_argument('--save', type=str, default='experiments/iresnet_toy')
64 | parser.add_argument('--viz_freq', type=int, default=100)
65 | parser.add_argument('--val_freq', type=int, default=100)
66 | parser.add_argument('--log_freq', type=int, default=10)
67 | parser.add_argument('--gpu', type=int, default=0)
68 | parser.add_argument('--seed', type=int, default=0)
69 | args = parser.parse_args()
70 |
71 | # logger
72 | utils.makedirs(args.save)
73 | logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
74 | logger.info(args)
75 |
76 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
77 |
78 | np.random.seed(args.seed)
79 | torch.manual_seed(args.seed)
80 | if device.type == 'cuda':
81 | torch.cuda.manual_seed(args.seed)
82 |
83 |
84 | def count_parameters(model):
85 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
86 |
87 |
88 | def standard_normal_sample(size):
89 | return torch.randn(size)
90 |
91 |
92 | def standard_normal_logprob(z):
93 | logZ = -0.5 * math.log(2 * math.pi)
94 | return logZ - z.pow(2) / 2
95 |
96 |
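# Maximum-likelihood objective for 2-D toy data: the flow returns z = f(x)
# and delta_logp, and the loss is the mean negative log-likelihood
# -E[log p(z) - beta * delta_logp].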
97 | def compute_loss(args, model, batch_size=None, beta=1.):
98 | if batch_size is None: batch_size = args.batch_size
99 |
100 | # load data
101 | x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
102 | x = torch.from_numpy(x).type(torch.float32).to(device)
103 | zero = torch.zeros(x.shape[0], 1).to(x)
104 |
105 | # transform to z
106 | z, delta_logp = model(x, zero)
107 |
108 | # compute log p(z)
109 | logpz = standard_normal_logprob(z).sum(1, keepdim=True)
110 |
111 | logpx = logpz - beta * delta_logp
112 | loss = -torch.mean(logpx)
113 | return loss, torch.mean(logpz), torch.mean(-delta_logp)
114 |
115 |
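# args.vnorms is a string such as '222222' with one p-norm per layer
# boundary ('f' denotes the infinity norm); consecutive entries are paired
# so that layer i maps an l_{p_i} domain into an l_{p_{i+1}} codomain.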
116 | def parse_vnorms():
117 | ps = []
118 | for p in args.vnorms:
119 | if p == 'f':
120 | ps.append(float('inf'))
121 | else:
122 | ps.append(float(p))
123 | return ps[:-1], ps[1:]
124 |
125 |
126 | def compute_p_grads(model):
127 | scales = 0.
128 | nlayers = 0
129 | for m in model.modules():
130 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
131 | scales = scales + m.compute_one_iter()
132 | nlayers += 1
133 | scales.mul(1 / nlayers).mul(0.01).backward()
134 | for m in model.modules():
135 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
136 | if m.domain.grad is not None and torch.isnan(m.domain.grad):
137 | m.domain.grad = None
138 |
139 |
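# Builds the Lipschitz-constrained MLP used inside each iResBlock: every
# linear layer is normalized to have Lipschitz constant at most args.coeff
# w.r.t. its (domain, codomain) norms. With --learn-p the norms become
# learnable parameters, independent per layer when --mixed, shared otherwise.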
140 | def build_nnet(dims, activation_fn=torch.nn.ReLU):
141 | nnet = []
142 | domains, codomains = parse_vnorms()
143 | if args.learn_p:
144 | if args.mixed:
145 | domains = [torch.nn.Parameter(torch.tensor(0.)) for _ in domains]
146 | else:
147 | domains = [torch.nn.Parameter(torch.tensor(0.))] * len(domains)
148 | codomains = domains[1:] + [domains[0]]
149 | for i, (in_dim, out_dim, domain, codomain) in enumerate(zip(dims[:-1], dims[1:], domains, codomains)):
150 | nnet.append(activation_fn())
151 | nnet.append(
152 | base_layers.get_linear(
153 | in_dim,
154 | out_dim,
155 | coeff=args.coeff,
156 | n_iterations=args.n_lipschitz_iters,
157 | atol=args.atol,
158 | rtol=args.rtol,
159 | domain=domain,
160 | codomain=codomain,
161 | zero_init=(out_dim == 2),
162 | )
163 | )
164 | return torch.nn.Sequential(*nnet)
165 |
166 |
167 | def update_lipschitz(model, n_iterations):
168 | for m in model.modules():
169 | if isinstance(m, base_layers.SpectralNormConv2d) or isinstance(m, base_layers.SpectralNormLinear):
170 | m.compute_weight(update=True, n_iterations=n_iterations)
171 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
172 | m.compute_weight(update=True, n_iterations=n_iterations)
173 |
174 |
175 | def get_ords(model):
176 | ords = []
177 | for m in model.modules():
178 | if isinstance(m, base_layers.InducedNormConv2d) or isinstance(m, base_layers.InducedNormLinear):
179 | domain, codomain = m.compute_domain_codomain()
180 | if torch.is_tensor(domain):
181 | domain = domain.item()
182 | if torch.is_tensor(codomain):
183 | codomain = codomain.item()
184 | ords.append(domain)
185 | ords.append(codomain)
186 | return ords
187 |
188 |
189 | def pretty_repr(a):
190 | return '[[' + ','.join(list(map(lambda i: f'{i:.2f}', a))) + ']]'
191 |
192 |
193 | if __name__ == '__main__':
194 |
195 | activation_fn = ACTIVATION_FNS[args.act]
196 |
197 | if args.arch == 'iresnet':
198 | dims = [2] + list(map(int, args.dims.split('-'))) + [2]
199 | blocks = []
200 | if args.actnorm: blocks.append(layers.ActNorm1d(2))
201 | for _ in range(args.nblocks):
202 | blocks.append(
203 | layers.iResBlock(
204 | build_nnet(dims, activation_fn),
205 | n_dist=args.n_dist,
206 | n_power_series=args.n_power_series,
207 | exact_trace=args.exact_trace,
208 | brute_force=args.brute_force,
209 | n_samples=args.n_samples,
210 | neumann_grad=False,
211 | grad_in_forward=False,
212 | )
213 | )
214 | if args.actnorm: blocks.append(layers.ActNorm1d(2))
215 | if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
216 | model = layers.SequentialFlow(blocks).to(device)
217 | elif args.arch == 'realnvp':
218 | blocks = []
219 | for _ in range(args.nblocks):
220 | blocks.append(layers.CouplingBlock(2, swap=False))
221 | blocks.append(layers.CouplingBlock(2, swap=True))
222 | if args.actnorm: blocks.append(layers.ActNorm1d(2))
223 | if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
224 | model = layers.SequentialFlow(blocks).to(device)
225 |
226 | logger.info(model)
227 | logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
228 |
229 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
230 |
231 | time_meter = utils.RunningAverageMeter(0.93)
232 | loss_meter = utils.RunningAverageMeter(0.93)
233 | logpz_meter = utils.RunningAverageMeter(0.93)
234 | delta_logp_meter = utils.RunningAverageMeter(0.93)
235 |
236 | end = time.time()
237 | best_loss = float('inf')
238 | model.train()
239 | for itr in range(1, args.niters + 1):
240 | optimizer.zero_grad()
241 |
242 | beta = min(1, itr / args.annealing_iters) if args.annealing_iters > 0 else 1.
243 | loss, logpz, delta_logp = compute_loss(args, model, beta=beta)
244 | loss_meter.update(loss.item())
245 | logpz_meter.update(logpz.item())
246 | delta_logp_meter.update(delta_logp.item())
247 | loss.backward()
248 | if args.learn_p and itr > args.annealing_iters: compute_p_grads(model)
249 | optimizer.step()
250 | update_lipschitz(model, args.n_lipschitz_iters)
251 |
252 | time_meter.update(time.time() - end)
253 |
254 |         if itr % args.log_freq == 0:
255 |             logger.info(
256 |                 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f})'
257 |                 ' | Logp(z) {:.6f}({:.6f}) | DeltaLogp {:.6f}({:.6f})'.format(
258 |                     itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
259 |                     logpz_meter.val, logpz_meter.avg, delta_logp_meter.val, delta_logp_meter.avg
260 |                 )
261 |             )
261 |
262 | if itr % args.val_freq == 0 or itr == args.niters:
263 | update_lipschitz(model, 200)
264 | with torch.no_grad():
265 | model.eval()
266 | test_loss, test_logpz, test_delta_logp = compute_loss(args, model, batch_size=args.test_batch_size)
267 | log_message = (
268 | '[TEST] Iter {:04d} | Test Loss {:.6f} '
269 | '| Test Logp(z) {:.6f} | Test DeltaLogp {:.6f}'.format(
270 | itr, test_loss.item(), test_logpz.item(), test_delta_logp.item()
271 | )
272 | )
273 | logger.info(log_message)
274 |
275 | logger.info('Ords: {}'.format(pretty_repr(get_ords(model))))
276 |
277 | if test_loss.item() < best_loss:
278 | best_loss = test_loss.item()
279 | utils.makedirs(args.save)
280 | torch.save({
281 | 'args': args,
282 | 'state_dict': model.state_dict(),
283 | }, os.path.join(args.save, 'checkpt.pth'))
284 | model.train()
285 |
286 | if itr == 1 or itr % args.viz_freq == 0:
287 | with torch.no_grad():
288 | model.eval()
289 | p_samples = toy_data.inf_train_gen(args.data, batch_size=20000)
290 |
291 | sample_fn, density_fn = model.inverse, model.forward
292 |
293 | plt.figure(figsize=(9, 3))
294 | visualize_transform(
295 | p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn,
296 | samples=True, npts=400, device=device
297 | )
298 | fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr))
299 | utils.makedirs(os.path.dirname(fig_filename))
300 | plt.savefig(fig_filename)
301 | plt.close()
302 | model.train()
303 |
304 | end = time.time()
305 |
306 | logger.info('Training has finished.')
307 |
--------------------------------------------------------------------------------