├── .gitignore
├── LICENSE
├── README.md
├── assets
├── 1-D.png
├── math.png
└── theory.png
├── lib
├── datasets.py
├── implicit_flow.py
├── layers
│ ├── __init__.py
│ ├── act_norm.py
│ ├── base
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── lipschitz.py
│ │ ├── mixed_lipschitz.py
│ │ └── utils.py
│ ├── broyden.py
│ ├── container.py
│ ├── coupling.py
│ ├── elemwise.py
│ ├── glow.py
│ ├── implicit_block.py
│ ├── iresblock.py
│ ├── mask_utils.py
│ ├── normalization.py
│ └── squeeze.py
├── lr_scheduler.py
├── optimizers.py
├── resflow.py
├── tabular.py
├── toy_data.py
├── utils.py
└── visualize_flow.py
├── preprocessing
├── convert_to_pth.py
├── create_imagenet_benchmark_datasets.py
└── extract_celeba_from_tfrecords.py
├── qualitative_samples.py
├── run_cifar10.sh
├── run_classification.sh
├── run_tabular.sh
├── run_toy.sh
├── train_classification.py
├── train_img.py
├── train_tabular.py
└── train_toy.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *__pycache__*
3 | data/*
4 | pretrained_models
5 | experiments
6 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Cheng Lu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Implicit Normalizing Flows (ICLR 2021 Spotlight)[[arxiv](https://arxiv.org/abs/2103.09527)][[slides](https://docs.google.com/presentation/d/1reCHEBjy9ygJ0bM_jvItEGU1jHDn7NfYBeC7FBqK-1Y/edit?usp=sharing)]
2 |
3 | This repository contains the PyTorch implementation of the experiments from the paper [Implicit Normalizing Flows](https://arxiv.org/abs/2103.09527). The implementation is based on [Residual Flows](https://github.com/rtqichen/residual-flows).
4 |
5 |
6 |
7 |
8 |
9 | Implicit Normalizing Flows generalize normalizing flows by allowing the invertible mapping to be **implicitly** defined by the roots of an equation F(z, x) = 0. Building on [Residual Flows](https://arxiv.org/abs/1906.02735), we propose:
10 |
11 | + A **unique** and **invertible** mapping defined by an equation of the latent variable and the observed variable.
12 | + A **more powerful function space** than Residual Flows, obtained by relaxing the Lipschitz constraints of Residual Flows.
13 | + A **scalable algorithm** for generative modeling.
14 |
15 | As an example, in the 1-D case, the function space of ImpFlows contains all strongly monotonic differentiable functions, i.e.
16 |
17 |
18 |
19 |
20 | A 1-D function fitting example:
21 |
22 |
23 |
24 |
25 |
26 | ## Requirements
27 |
28 | - PyTorch 1.4, torchvision 0.5
29 | - Python 3.6+
30 |
31 | ## Density Estimation Experiments
32 |
33 | ***NOTE***: By default, O(1)-memory gradients are enabled. However, the logged bits/dim during training will not be an actual estimate of bits/dim but whatever scalar was used to generate the unbiased gradients. If you want to check the actual bits/dim for training (and have sufficient GPU memory), set `--neumann-grad=False`. Note however that the memory cost can stochastically vary during training if this flag is `False`.
34 |
35 | Toy 2-D data:
36 | ```
37 | bash run_toy.sh
38 | ```
39 |
40 | CIFAR10:
41 | ```
42 | bash run_cifar10.sh
43 | ```
44 |
45 | Tabular Data:
46 | ```
47 | bash run_tabular.sh
48 | ```
49 |
50 | ## BibTeX
51 | ```
52 | To be done after ICLR 2021.
53 | ```
54 |
--------------------------------------------------------------------------------
/assets/1-D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/implicit-normalizing-flows/d8babe48bdea7b1ba98be698bef2e954ac3817ee/assets/1-D.png
--------------------------------------------------------------------------------
/assets/math.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/implicit-normalizing-flows/d8babe48bdea7b1ba98be698bef2e954ac3817ee/assets/math.png
--------------------------------------------------------------------------------
/assets/theory.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/implicit-normalizing-flows/d8babe48bdea7b1ba98be698bef2e954ac3817ee/assets/theory.png
--------------------------------------------------------------------------------
/lib/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.datasets as vdsets
3 |
4 |
class Dataset(object):
    """Image dataset backed by a single tensor loaded with ``torch.load``.

    Args:
        loc: Path handed to ``torch.load``.
        transform: Optional callable applied to every sample on access.
        in_mem: When true the whole tensor is converted to float and divided
            by 255 once at construction; otherwise the conversion happens per
            sample in ``__getitem__``.
    """

    def __init__(self, loc, transform=None, in_mem=True):
        data = torch.load(loc)
        if in_mem:
            # Pay the float conversion once up front.
            data = data.float().div(255)
        self.dataset = data
        self.in_mem = in_mem
        self.transform = transform

    def __len__(self):
        return self.dataset.shape[0]

    @property
    def ndim(self):
        """Size of dimension 1 of the stored tensor."""
        return self.dataset.shape[1]

    def __getitem__(self, index):
        """Return ``(sample, 0)``; the label is always the constant 0."""
        sample = self.dataset[index]
        if not self.in_mem:
            sample = sample.float().div(255)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, 0
25 |
26 |
class MNIST(object):
    """Wrapper around ``torchvision.datasets.MNIST`` that also exposes ``ndim``.

    The underlying dataset is created with ``download=True``.
    """

    def __init__(self, dataroot, train=True, transform=None):
        options = dict(train=train, download=True, transform=transform)
        self.mnist = vdsets.MNIST(dataroot, **options)

    def __len__(self):
        wrapped = self.mnist
        return len(wrapped)

    @property
    def ndim(self):
        """Constant 1."""
        return 1

    def __getitem__(self, index):
        wrapped = self.mnist
        return wrapped[index]
41 |
42 |
class CIFAR10(object):
    """Wrapper around ``torchvision.datasets.CIFAR10`` that also exposes ``ndim``.

    The underlying dataset is created with ``download=True``.
    """

    def __init__(self, dataroot, train=True, transform=None):
        options = dict(train=train, download=True, transform=transform)
        self.cifar10 = vdsets.CIFAR10(dataroot, **options)

    def __len__(self):
        wrapped = self.cifar10
        return len(wrapped)

    @property
    def ndim(self):
        """Constant 3."""
        return 3

    def __getitem__(self, index):
        wrapped = self.cifar10
        return wrapped[index]
57 |
58 |
class CelebA5bit(object):
    """Dataset loaded from the fixed path ``LOC`` and scaled by 1/31.

    Args:
        train: When false, only the first 5000 entries are kept.
        transform: Optional callable applied to every sample on access.
    """

    LOC = 'data/celebahq64_5bit/celeba_full_64x64_5bit.pth'

    def __init__(self, train=True, transform=None):
        data = torch.load(self.LOC).float().div(31)
        # The evaluation split is simply the leading 5000 entries.
        self.dataset = data if train else data[:5000]
        self.transform = transform

    def __len__(self):
        return self.dataset.shape[0]

    @property
    def ndim(self):
        """Size of dimension 1 of the stored tensor."""
        return self.dataset.shape[1]

    def __getitem__(self, index):
        """Return ``(sample, 0)``; the label is always the constant 0."""
        sample = self.dataset[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, 0
80 |
81 |
class CelebAHQ(Dataset):
    """``Dataset`` bound to the fixed CelebA-HQ train/validation file paths."""

    TRAIN_LOC = 'data/celebahq/celeba256_train.pth'
    TEST_LOC = 'data/celebahq/celeba256_validation.pth'

    def __init__(self, train=True, transform=None):
        if train:
            loc = self.TRAIN_LOC
        else:
            loc = self.TEST_LOC
        super().__init__(loc, transform)
88 |
89 |
class Imagenet32(Dataset):
    """``Dataset`` bound to the fixed 32x32 ImageNet train/valid file paths."""

    TRAIN_LOC = 'data/imagenet32/train_32x32.pth'
    TEST_LOC = 'data/imagenet32/valid_32x32.pth'

    def __init__(self, train=True, transform=None):
        if train:
            loc = self.TRAIN_LOC
        else:
            loc = self.TEST_LOC
        super().__init__(loc, transform)
96 |
97 |
class Imagenet64(Dataset):
    """``Dataset`` bound to the fixed 64x64 ImageNet train/valid file paths.

    Constructed with ``in_mem=False``, so samples are converted to float on
    access rather than all at once.
    """

    TRAIN_LOC = 'data/imagenet64/train_64x64.pth'
    TEST_LOC = 'data/imagenet64/valid_64x64.pth'

    def __init__(self, train=True, transform=None):
        if train:
            loc = self.TRAIN_LOC
        else:
            loc = self.TEST_LOC
        super().__init__(loc, transform, in_mem=False)
104 |
--------------------------------------------------------------------------------
/lib/implicit_flow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | import lib.layers as layers
6 | import lib.layers.base as base_layers
7 |
# Activation factories keyed by the name passed as ``activation_fn``.
# Each factory takes one positional flag ``b``; only the 'elu' and 'relu'
# entries forward it (as ``inplace``), the others ignore it.
ACT_FNS = {
    'softplus': lambda b: nn.Softplus(),
    'elu': lambda b: nn.ELU(inplace=b),
    'swish': lambda b: base_layers.Swish(),
    'lcube': lambda b: base_layers.LipschitzCube(),
    'identity': lambda b: base_layers.Identity(),
    'relu': lambda b: nn.ReLU(inplace=b),
    'sin': lambda b: base_layers.Sin(),
    'zero': lambda b: base_layers.Zero(),
}
18 |
19 |
class ImplicitFlow(nn.Module):
    """Multi-scale flow built from ``StackedImplicitBlocks``.

    ``input_size`` is unpacked as ``(n, c, h, w)``. One
    ``StackedImplicitBlocks`` is created per scale; the number of scales is
    the smaller of ``len(n_blocks)`` and the number of times ``h`` and ``w``
    can be halved while both are still at least 4. With ``factor_out=True``
    half of the channels are split off after every scale except the last and
    emitted directly as part of the output.

    Apart from ``input_size``, ``n_blocks``, ``factor_out`` and the
    ``classification*`` / ``n_classes`` arguments, every constructor argument
    is stored on the instance and forwarded unchanged to each
    ``StackedImplicitBlocks``.

    Raises:
        ValueError: If no scale can be built for ``input_size``.
    """

    def __init__(
        self,
        input_size,
        n_blocks=[16, 16],  # only read (len / indexing) in this class, never mutated
        intermediate_dim=64,
        factor_out=True,
        quadratic=False,
        init_layer=None,
        actnorm=False,
        fc_actnorm=False,
        batchnorm=False,
        dropout=0,
        fc=False,
        coeff=0.9,
        vnorms='122f',
        n_lipschitz_iters=None,
        sn_atol=None,
        sn_rtol=None,
        n_power_series=5,
        n_dist='geometric',
        n_samples=1,
        kernels='3-1-3',
        activation_fn='elu',
        fc_end=True,
        fc_idim=128,
        n_exact_terms=0,
        preact=False,
        neumann_grad=True,
        grad_in_forward=False,
        first_resblock=True,
        learn_p=False,
        classification=False,
        classification_hdim=64,
        n_classes=10,
    ):
        super(ImplicitFlow, self).__init__()
        self.n_scale = min(len(n_blocks), self._calc_n_scale(input_size))
        self.n_blocks = n_blocks
        self.intermediate_dim = intermediate_dim
        self.factor_out = factor_out
        self.quadratic = quadratic
        self.init_layer = init_layer
        self.actnorm = actnorm
        self.fc_actnorm = fc_actnorm
        self.batchnorm = batchnorm
        self.dropout = dropout
        self.fc = fc
        self.coeff = coeff
        self.vnorms = vnorms
        self.n_lipschitz_iters = n_lipschitz_iters
        self.sn_atol = sn_atol
        self.sn_rtol = sn_rtol
        self.n_power_series = n_power_series
        self.n_dist = n_dist
        self.n_samples = n_samples
        self.kernels = kernels
        self.activation_fn = activation_fn
        self.fc_end = fc_end
        self.fc_idim = fc_idim
        self.n_exact_terms = n_exact_terms
        self.preact = preact
        self.neumann_grad = neumann_grad
        self.grad_in_forward = grad_in_forward
        self.first_resblock = first_resblock
        self.learn_p = learn_p
        self.classification = classification
        self.classification_hdim = classification_hdim
        self.n_classes = n_classes

        if not self.n_scale > 0:
            # tuple(...) lets list-valued sizes be formatted too; the message
            # previously ran "of" and "size" together.
            raise ValueError(
                'Could not compute number of scales for input of size (%d,%d,%d,%d)' % tuple(input_size)
            )

        self.transforms = self._build_net(input_size)

        # Per-scale output shapes without the batch dimension; used by
        # ``inverse`` to un-flatten a latent vector.
        self.dims = [o[1:] for o in self.calc_output_size(input_size)]

        if self.classification:
            self.build_multiscale_classifier(input_size)

    def _build_net(self, input_size):
        """Create one ``StackedImplicitBlocks`` per scale as a ``ModuleList``."""
        _, c, h, w = input_size
        transforms = []
        for i in range(self.n_scale):
            transforms.append(
                StackedImplicitBlocks(
                    initial_size=(c, h, w),
                    idim=self.intermediate_dim,
                    squeeze=(i < self.n_scale - 1),  # don't squeeze last layer
                    init_layer=self.init_layer if i == 0 else None,
                    n_blocks=self.n_blocks[i],
                    quadratic=self.quadratic,
                    actnorm=self.actnorm,
                    fc_actnorm=self.fc_actnorm,
                    batchnorm=self.batchnorm,
                    dropout=self.dropout,
                    fc=self.fc,
                    coeff=self.coeff,
                    vnorms=self.vnorms,
                    n_lipschitz_iters=self.n_lipschitz_iters,
                    sn_atol=self.sn_atol,
                    sn_rtol=self.sn_rtol,
                    n_power_series=self.n_power_series,
                    n_dist=self.n_dist,
                    n_samples=self.n_samples,
                    kernels=self.kernels,
                    activation_fn=self.activation_fn,
                    fc_end=self.fc_end,
                    fc_idim=self.fc_idim,
                    n_exact_terms=self.n_exact_terms,
                    preact=self.preact,
                    neumann_grad=self.neumann_grad,
                    grad_in_forward=self.grad_in_forward,
                    first_resblock=self.first_resblock and (i == 0),
                    learn_p=self.learn_p,
                )
            )
            # Channel count carried to the next scale: doubled when half of
            # the channels are factored out, quadrupled otherwise.
            c = c * 2 if self.factor_out else c * 4
            h, w = h // 2, w // 2
        return nn.ModuleList(transforms)

    def _calc_n_scale(self, input_size):
        """Count how often ``h`` and ``w`` can be halved while both are >= 4."""
        _, _, h, w = input_size
        n_scale = 0
        while h >= 4 and w >= 4:
            n_scale += 1
            h = h // 2
            w = w // 2
        return n_scale

    def calc_output_size(self, input_size):
        """Return the output shape(s) produced for ``input_size``.

        Without factoring out, a one-element list holding the final
        ``[n, c, h, w]`` is returned; otherwise a tuple with one
        ``(n, c, h, w)`` entry per scale.
        """
        n, c, h, w = input_size
        if not self.factor_out:
            k = self.n_scale - 1
            return [[n, c * 4**k, h // 2**k, w // 2**k]]
        output_sizes = []
        for i in range(self.n_scale):
            if i < self.n_scale - 1:
                c *= 2
                h //= 2
                w //= 2
            # The last scale repeats the shape reached by the previous one.
            output_sizes.append((n, c, h, w))
        return tuple(output_sizes)

    def build_multiscale_classifier(self, input_size):
        """Create one classification head per scale plus a shared logit layer."""
        n, c, h, w = input_size
        hidden_shapes = []
        for i in range(self.n_scale):
            if i < self.n_scale - 1:
                c *= 2 if self.factor_out else 4
                h //= 2
                w //= 2
            hidden_shapes.append((n, c, h, w))

        classification_heads = [
            nn.Sequential(
                nn.Conv2d(hshape[1], self.classification_hdim, 3, 1, 1),
                layers.ActNorm2d(self.classification_hdim),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d((1, 1)),
            ) for hshape in hidden_shapes
        ]
        self.classification_heads = nn.ModuleList(classification_heads)
        self.logit_layer = nn.Linear(self.classification_hdim * len(classification_heads), self.n_classes)

    def forward(self, x, logpx=None, inverse=False, classify=False, restore=False):
        """Run the flow, or its inverse when ``inverse`` is true.

        Returns the per-scale outputs flattened and concatenated along
        dimension 1, paired with the updated ``logpx`` when one was given.
        With ``classify=True`` the result is ``(output, logits)``.
        """
        if inverse:
            return self.inverse(x, logpx)
        out = []
        if classify: class_outs = []
        last_idx = len(self.transforms) - 1
        for idx, transform in enumerate(self.transforms):
            if logpx is not None:
                x, logpx = transform.forward(x, logpx, restore=restore)
            else:
                x = transform.forward(x, restore=restore)
            if self.factor_out and (idx < last_idx):
                d = x.size(1) // 2
                x, f = x[:, :d], x[:, d:]
                out.append(f)

            # Handle classification.
            if classify:
                if self.factor_out:
                    # NOTE(review): at the last scale nothing is factored out,
                    # so ``f`` is still the tensor from the previous scale (or
                    # undefined when there is a single scale). Presumably ``x``
                    # was intended there; left as is because changing it would
                    # alter trained classifiers. Confirm.
                    class_outs.append(self.classification_heads[idx](f))
                else:
                    class_outs.append(self.classification_heads[idx](x))

        out.append(x)
        out = torch.cat([o.view(o.size()[0], -1) for o in out], 1)
        output = out if logpx is None else (out, logpx)
        if classify:
            h = torch.cat(class_outs, dim=1).squeeze(-1).squeeze(-1)
            logits = self.logit_layer(h)
            return output, logits
        return output

    def inverse(self, z, logpz=None):
        """Map a latent ``z`` back through the scales in reverse order.

        Returns the reconstructed input, paired with the updated ``logpz``
        when one was given.
        """
        last_idx = len(self.transforms) - 1
        if self.factor_out:
            # Cut the flat latent into one chunk per scale and restore shapes.
            z = z.view(z.shape[0], -1)
            zs = []
            start = 0
            for dims in self.dims:
                size = int(np.prod(dims))
                zs.append(z[:, start:start + size].view(z.size()[0], *dims))
                start += size

            z_prev = zs[-1]
            for idx in range(last_idx, -1, -1):
                if idx < last_idx:
                    # Re-attach the channels factored out after scale ``idx``.
                    z_prev = torch.cat((z_prev, zs[idx]), dim=1)
                if logpz is None:
                    z_prev = self.transforms[idx].inverse(z_prev)
                else:
                    z_prev, logpz = self.transforms[idx].inverse(z_prev, logpz)
            return z_prev if logpz is None else (z_prev, logpz)

        z = z.view(z.shape[0], *self.dims[-1])
        for idx in range(last_idx, -1, -1):
            if logpz is None:
                z = self.transforms[idx].inverse(z)
            else:
                z, logpz = self.transforms[idx].inverse(z, logpz)
        return z if logpz is None else (z, logpz)
252 |
253 |
class StackedImplicitBlocks(layers.SequentialFlow):
    """One scale of the flow: a chain of ``layers.imBlock`` modules.

    The chain is assembled in ``__init__`` and handed to
    ``layers.SequentialFlow``. In order it contains: the optional
    ``init_layer``; optional activation-normalisation layers (only when
    ``first_resblock`` is true); ``n_blocks`` repetitions of
    [optional invertible linear/conv layer, implicit block, optional actnorm
    layers]; and finally either a ``layers.SqueezeLayer(2)`` (``squeeze=True``)
    or, when ``fc_end`` is set, ``fc_nblocks`` fully connected implicit blocks.

    ``vnorms`` is parsed character by character into norm orders
    (``'f'`` becomes infinity); consecutive pairs give the domain/codomain
    of each Lipschitz layer, so it must be one character longer than the
    number of entries in ``kernels``.
    """

    def __init__(
        self,
        initial_size,
        idim,
        squeeze=True,
        init_layer=None,
        n_blocks=1,
        quadratic=False,
        actnorm=False,
        fc_actnorm=False,
        batchnorm=False,
        dropout=0,
        fc=False,
        coeff=0.9,
        vnorms='122f',
        n_lipschitz_iters=None,
        sn_atol=None,
        sn_rtol=None,
        n_power_series=5,
        n_dist='geometric',
        n_samples=1,
        kernels='3-1-3',
        activation_fn='elu',
        fc_end=True,
        fc_nblocks=2,
        fc_idim=128,
        n_exact_terms=0,
        preact=False,
        neumann_grad=True,
        grad_in_forward=False,
        first_resblock=True,
        learn_p=False,
    ):

        chain = []

        # Parse vnorms
        ps = []
        for p in vnorms:
            if p == 'f':
                ps.append(float('inf'))
            else:
                ps.append(float(p))
        # Layer i maps from norm ps[i] to norm ps[i + 1].
        domains, codomains = ps[:-1], ps[1:]
        assert len(domains) == len(kernels.split('-'))

        def _actnorm(size, fc):
            # Per-element ActNorm1d on the flattened input, or per-channel ActNorm2d.
            if fc:
                return FCWrapper(layers.ActNorm1d(size[0] * size[1] * size[2]))
            else:
                return layers.ActNorm2d(size[0])

        def _quadratic_layer(initial_size, fc):
            # Invertible linear layer on the flattened input, or an invertible conv.
            if fc:
                c, h, w = initial_size
                dim = c * h * w
                return FCWrapper(layers.InvertibleLinear(dim))
            else:
                return layers.InvertibleConv2d(initial_size[0])

        def _lipschitz_layer(fc):
            # Factory for the Lipschitz-constrained linear or conv layer.
            return base_layers.get_linear if fc else base_layers.get_conv2d

        def _resblock(initial_size, fc, idim=idim, first_resblock=True):
            # Build one implicit block from two separately constructed networks
            # with the same architecture.
            if fc:
                return layers.imBlock(
                    FCNet(
                        input_shape=initial_size,
                        idim=idim,
                        lipschitz_layer=_lipschitz_layer(True),
                        nhidden=len(kernels.split('-')) - 1,
                        coeff=coeff,
                        domains=domains,
                        codomains=codomains,
                        n_iterations=n_lipschitz_iters,
                        activation_fn=activation_fn,
                        preact=preact,
                        dropout=dropout,
                        sn_atol=sn_atol,
                        sn_rtol=sn_rtol,
                        learn_p=learn_p,
                    ),
                    FCNet(
                        input_shape=initial_size,
                        idim=idim,
                        lipschitz_layer=_lipschitz_layer(True),
                        nhidden=len(kernels.split('-')) - 1,
                        coeff=coeff,
                        domains=domains,
                        codomains=codomains,
                        n_iterations=n_lipschitz_iters,
                        activation_fn=activation_fn,
                        preact=preact,
                        dropout=dropout,
                        sn_atol=sn_atol,
                        sn_rtol=sn_rtol,
                        learn_p=learn_p,
                    ),
                    n_power_series=n_power_series,
                    n_dist=n_dist,
                    n_samples=n_samples,
                    n_exact_terms=n_exact_terms,
                    neumann_grad=neumann_grad,
                    grad_in_forward=grad_in_forward,
                )
            else:
                def build_nnet():
                    # Convolutional network; ``idim`` and ``first_resblock``
                    # here are the arguments of the enclosing ``_resblock``.
                    ks = list(map(int, kernels.split('-')))
                    if learn_p:
                        # One fresh parameter per layer; each layer's codomain
                        # is the next layer's domain, wrapping around at the end.
                        _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                        _codomains = _domains[1:] + [_domains[0]]
                    else:
                        _domains = domains
                        _codomains = codomains
                    nnet = []
                    # Pre-activation is skipped for the first block.
                    if not first_resblock and preact:
                        if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                        nnet.append(ACT_FNS[activation_fn](False))
                    nnet.append(
                        _lipschitz_layer(fc)(
                            initial_size[0], idim, ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
                        )
                    )
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                    nnet.append(ACT_FNS[activation_fn](True))
                    for i, k in enumerate(ks[1:-1]):
                        nnet.append(
                            _lipschitz_layer(fc)(
                                idim, idim, k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                                domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol
                            )
                        )
                        if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                        nnet.append(ACT_FNS[activation_fn](True))
                    if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
                    nnet.append(
                        _lipschitz_layer(fc)(
                            idim, initial_size[0], ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol
                        )
                    )
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                    return nn.Sequential(*nnet)
                return layers.imBlock(
                    build_nnet(),
                    build_nnet(),
                    n_power_series=n_power_series,
                    n_dist=n_dist,
                    n_samples=n_samples,
                    n_exact_terms=n_exact_terms,
                    neumann_grad=neumann_grad,
                    grad_in_forward=grad_in_forward,
                )

        if init_layer is not None: chain.append(init_layer)
        if first_resblock and actnorm: chain.append(_actnorm(initial_size, fc))
        if first_resblock and fc_actnorm: chain.append(_actnorm(initial_size, True))

        if squeeze:
            c, h, w = initial_size
            for i in range(n_blocks):
                if quadratic: chain.append(_quadratic_layer(initial_size, fc))
                chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0)))
                if actnorm: chain.append(_actnorm(initial_size, fc))
                if fc_actnorm: chain.append(_actnorm(initial_size, True))
            chain.append(layers.SqueezeLayer(2))
        else:
            for i in range(n_blocks):
                if quadratic: chain.append(_quadratic_layer(initial_size, fc))
                chain.append(_resblock(initial_size, fc, first_resblock=first_resblock and (i == 0)))
                if actnorm: chain.append(_actnorm(initial_size, fc))
                if fc_actnorm: chain.append(_actnorm(initial_size, True))
            # Optionally append ``fc_nblocks`` fully connected blocks at the end.
            if fc_end:
                for _ in range(fc_nblocks):
                    chain.append(_resblock(initial_size, True, fc_idim))
                    if actnorm or fc_actnorm: chain.append(_actnorm(initial_size, True))
        super(StackedImplicitBlocks, self).__init__(chain)
435 |
436 |
class FCNet(nn.Module):
    """Fully connected network applied to the flattened ``(c, h, w)`` input.

    The stack is ``nhidden`` hidden layers of width ``idim`` (each followed
    by the chosen activation), an optional dropout layer, and an output layer
    back to ``c * h * w`` features. ``forward`` reshapes the result to
    ``input_shape``. When ``lipschitz_layer`` is ``nn.Linear`` the layers are
    built without the Lipschitz-related keyword arguments.
    """

    def __init__(
        self, input_shape, idim, lipschitz_layer, nhidden, coeff, domains, codomains, n_iterations, activation_fn,
        preact, dropout, sn_atol, sn_rtol, learn_p, div_in=1
    ):
        super(FCNet, self).__init__()
        self.input_shape = input_shape
        channels, height, width = self.input_shape
        flat_dim = channels * height * width

        modules = []
        if preact:
            modules.append(ACT_FNS[activation_fn](False))

        if learn_p:
            # One fresh parameter per layer; each layer's codomain is the next
            # layer's domain, wrapping around at the end.
            domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(domains))]
            codomains = domains[1:] + [domains[0]]

        def make_layer(n_in, n_out, pos):
            """Build one linear layer using the norms stored at index ``pos``."""
            if lipschitz_layer == nn.Linear:
                return lipschitz_layer(n_in, n_out)
            return lipschitz_layer(
                n_in, n_out, coeff=coeff, n_iterations=n_iterations, domain=domains[pos], codomain=codomains[pos],
                atol=sn_atol, rtol=sn_rtol
            )

        width_in = flat_dim // div_in
        for layer_idx in range(nhidden):
            modules.append(make_layer(width_in, idim, layer_idx))
            modules.append(ACT_FNS[activation_fn](True))
            width_in = idim
        if dropout:
            modules.append(nn.Dropout(dropout, inplace=True))
        modules.append(make_layer(width_in, flat_dim, -1))
        self.nnet = nn.Sequential(*modules)

    def forward(self, x, restore=False):
        """Flatten ``x`` per sample, apply the network, restore ``input_shape``."""
        batch = x.shape[0]
        flat_out = self.nnet(x.view(batch, -1))
        return flat_out.view(flat_out.shape[0], *self.input_shape)
475 |
476 |
class FCWrapper(nn.Module):
    """Adapter that runs a module on per-sample flattened inputs.

    Inputs are viewed as ``(batch, -1)`` before being passed to the wrapped
    module, and its output is viewed back to the original shape. An optional
    log-density argument is passed through to the wrapped module.
    """

    def __init__(self, fc_module):
        super(FCWrapper, self).__init__()
        self.fc_module = fc_module

    def forward(self, x, logpx=None, restore=False):
        original_shape = x.shape
        flat = x.view(original_shape[0], -1)
        if logpx is not None:
            out, logp_out = self.fc_module(flat, logpx)
            return out.view(*original_shape), logp_out
        return self.fc_module(flat).view(*original_shape)

    def inverse(self, y, logpy=None):
        original_shape = y.shape
        flat = y.view(original_shape[0], -1)
        if logpy is not None:
            out, logp_out = self.fc_module.inverse(flat, logpy)
            return out.view(*original_shape), logp_out
        return self.fc_module.inverse(flat).view(*original_shape)
502 |
--------------------------------------------------------------------------------
/lib/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .act_norm import *
2 | from .container import *
3 | from .coupling import *
4 | from .elemwise import *
5 | from .iresblock import *
6 | from .implicit_block import *
7 | from .normalization import *
8 | from .squeeze import *
9 | from .glow import *
10 |
--------------------------------------------------------------------------------
/lib/layers/act_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Parameter
4 |
5 | __all__ = ['ActNorm1d', 'ActNorm2d']
6 |
7 |
class ActNormNd(nn.Module):
    """Per-feature affine layer ``y = (x + bias) * exp(weight)``.

    ``bias`` and ``weight`` are set from the statistics of the first batch
    seen by ``forward`` (negative mean and ``-0.5 * log(var)``, with the
    variance floored at 0.2); the ``initialized`` buffer records that this
    has happened. Subclasses supply ``shape``, the view used to broadcast
    the parameters against the input.
    """

    def __init__(self, num_features, eps=1e-12):
        super(ActNormNd, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('initialized', torch.tensor(0))

    @property
    def shape(self):
        raise NotImplementedError

    def _init_from_batch(self, x):
        """Set ``bias`` and ``weight`` from the statistics of ``x``."""
        n_channels = x.size(1)
        with torch.no_grad():
            # One row per feature, all remaining dimensions flattened.
            per_feature = x.transpose(0, 1).contiguous().view(n_channels, -1)
            mean = torch.mean(per_feature, dim=1)
            var = torch.var(per_feature, dim=1)

            # Floor the variance for numerical stability.
            var = torch.max(var, torch.tensor(0.2).to(var))

            self.bias.data.copy_(-mean)
            self.weight.data.copy_(-0.5 * torch.log(var))
            self.initialized.fill_(1)

    def _expanded_params(self, ref):
        """Return ``(bias, weight)`` broadcast to the shape of ``ref``."""
        view_shape = self.shape
        bias = self.bias.view(*view_shape).expand_as(ref)
        weight = self.weight.view(*view_shape).expand_as(ref)
        return bias, weight

    def forward(self, x, logpx=None, restore=None):
        if not self.initialized:
            self._init_from_batch(x)

        bias, weight = self._expanded_params(x)
        y = (x + bias) * torch.exp(weight)

        if logpx is not None:
            return y, logpx - self._logdetgrad(x)
        return y

    def inverse(self, y, logpy=None):
        assert self.initialized
        bias, weight = self._expanded_params(y)
        x = y * torch.exp(-weight) - bias

        if logpy is not None:
            return x, logpy + self._logdetgrad(x)
        return x

    def _logdetgrad(self, x):
        """Sum of ``weight`` over all non-batch positions, shape ``(batch, 1)``."""
        expanded = self.weight.view(*self.shape).expand(*x.size())
        per_sample = expanded.contiguous().view(x.size(0), -1)
        return per_sample.sum(1, keepdim=True)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.num_features)
66 |
67 |
class ActNorm1d(ActNormNd):
    """``ActNormNd`` whose parameters are viewed as ``[1, -1]`` for broadcasting."""

    @property
    def shape(self):
        # View applied to weight/bias before they are expanded to the input.
        return [1, -1]
73 |
74 |
class ActNorm2d(ActNormNd):
    """``ActNormNd`` whose parameters are viewed as ``[1, -1, 1, 1]`` for broadcasting."""

    @property
    def shape(self):
        # View applied to weight/bias before they are expanded to the input.
        return [1, -1, 1, 1]
80 |
--------------------------------------------------------------------------------
/lib/layers/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .activations import *
2 | from .lipschitz import *
3 | from .mixed_lipschitz import *
4 |
--------------------------------------------------------------------------------
/lib/layers/base/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
class Sin(nn.Module):
    """Element-wise ``sin(2 * pi * x) / pi * 0.5``."""

    def forward(self, x):
        wave = torch.sin(2. * math.pi * x)
        return wave / math.pi * 0.5
13 |
14 |
class Identity(nn.Module):
    """Module that returns its input unchanged."""

    def forward(self, x):
        return x
19 |
class Zero(nn.Module):
    """Module that returns zeros with the same shape, dtype and device as its input."""

    def forward(self, x):
        # zeros_like already matches the dtype and device of x.
        return torch.zeros_like(x)
24 |
class FullSort(nn.Module):
    """Return the input sorted in ascending order along dimension 1."""

    def forward(self, x):
        sorted_values, _ = torch.sort(x, dim=1)
        return sorted_values
29 |
30 |
class MaxMin(nn.Module):
    """Pair up adjacent features of a 2-D input and output all maxima, then all minima.

    The input is unpacked as ``(b, d)`` and viewed as ``(b, d // 2, 2)``.
    """

    def forward(self, x):
        batch, width = x.shape
        pairs = x.view(batch, width // 2, 2)
        upper, _ = torch.max(pairs, 2)
        lower, _ = torch.min(pairs, 2)
        return torch.cat([upper, lower], 1)
38 |
39 |
class LipschitzCube(nn.Module):
    """Piecewise activation: ``x**3 / 3`` on (-1, 1), shifted identity outside.

    For ``x >= 1`` the output is ``x - 2/3`` and for ``x <= -1`` it is
    ``x + 2/3``. The branches are combined with multiplicative masks.
    """

    def forward(self, x):
        upper_mask = (x >= 1).to(x)
        lower_mask = (x <= -1).to(x)
        inner_mask = ((x > -1) * (x < 1)).to(x)

        upper_part = upper_mask * (x - 2 / 3)
        lower_part = lower_mask * (x + 2 / 3)
        inner_part = inner_mask * x**3 / 3
        return upper_part + lower_part + inner_part
44 |
45 |
class SwishFn(torch.autograd.Function):
    """Custom autograd function for ``x * sigmoid(beta * x) / 1.1``.

    ``backward`` recomputes the sigmoid from the saved inputs. It used to
    recover it as ``output / x``, which is 0/0 (NaN) wherever ``x`` is zero.
    """

    @staticmethod
    def forward(ctx, x, beta):
        """Return ``x * sigmoid(beta * x) / 1.1`` and save the inputs."""
        output = x * torch.sigmoid(beta * x)
        ctx.save_for_backward(x, beta)
        return output / 1.1

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to ``x`` and ``beta``."""
        x, beta = ctx.saved_tensors
        beta_sigm = torch.sigmoid(beta * x)
        output = x * beta_sigm
        # d/dx [x * s] = s + beta * x * s * (1 - s), rewritten via output = x * s.
        grad_x = grad_output * (beta * output + beta_sigm * (1 - beta * output))
        # d/dbeta [x * s] = x**2 * s * (1 - s) = x * output - output**2, summed over all elements.
        grad_beta = torch.sum(grad_output * (x * output - output * output)).expand_as(beta)
        return grad_x / 1.1, grad_beta / 1.1
62 |
63 |
class Swish(nn.Module):
    """Activation ``x * sigmoid(x * softplus(beta)) / 1.1`` with a learnable ``beta``.

    ``beta`` is a one-element parameter initialised to 0.5.
    """

    def __init__(self):
        super(Swish, self).__init__()
        self.beta = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x):
        # softplus keeps the effective slope positive.
        slope = F.softplus(self.beta)
        # The in-place ops only touch freshly created intermediates, never x.
        gate = torch.sigmoid_(x * slope)
        gated = x * gate
        return gated.div_(1.1)
72 |
73 |
if __name__ == '__main__':
    # Visual check: plot Swish, its derivative, and where |derivative| exceeds 1.
    swish = Swish()
    inputs = torch.linspace(-5, 5, 1000).requires_grad_(True)
    outputs = swish(inputs)
    deriv, dbeta = torch.autograd.grad(outputs.sum() * 2, [inputs, swish.beta])

    import matplotlib.pyplot as plt

    xs = inputs.detach().numpy()
    excess = torch.max(deriv.detach().abs() - 1, torch.zeros_like(deriv)).numpy()

    curves = [
        (outputs.detach().numpy(), 'Func'),
        (deriv.detach().numpy(), 'Deriv'),
        (excess, '|Deriv| > 1'),
    ]
    for values, label in curves:
        plt.plot(xs, values, label=label)
    plt.legend()
    plt.tight_layout()
    plt.show()
89 |
--------------------------------------------------------------------------------
/lib/layers/base/mixed_lipschitz.py:
--------------------------------------------------------------------------------
1 | from torch._six import container_abcs
2 | from itertools import repeat
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.init as init
7 | import torch.nn.functional as F
8 |
9 | __all__ = ['InducedNormLinear', 'InducedNormConv2d']
10 |
11 |
class InducedNormLinear(nn.Module):
    """Linear layer whose weight is softly rescaled by an induced-norm estimate.

    An estimate ``sigma = u^T W v`` of the induced (``domain`` -> ``codomain``)
    norm of ``weight`` is maintained by an alternating iteration on the
    buffers ``u`` and ``v``. The weight actually used is
    ``weight / max(1, sigma / coeff)``, i.e. it is only shrunk when the
    estimate exceeds ``coeff``.

    Args:
        in_features: number of input features.
        out_features: number of output features.
        bias: whether to learn an additive bias.
        coeff: upper bound enforced on the estimated norm.
        domain: order of the input norm; either a number or a tensor that is
            squashed with ``asym_squash`` before use.
        codomain: order of the output norm (same convention as ``domain``).
        n_iterations: fixed number of iterations per update, or None to
            iterate until the (atol, rtol) criterion is met.
        atol: absolute tolerance of the stopping criterion.
        rtol: relative tolerance of the stopping criterion.
        zero_init: if True, shrink the initial weight by a factor of 1000.
    """

    def __init__(
        self, in_features, out_features, bias=True, coeff=0.97, domain=2, codomain=2, n_iterations=None, atol=None,
        rtol=None, zero_init=False, **unused_kwargs
    ):
        del unused_kwargs
        super(InducedNormLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.coeff = coeff
        self.n_iterations = n_iterations
        self.atol = atol
        self.rtol = rtol
        self.domain = domain
        self.codomain = codomain
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters(zero_init)

        with torch.no_grad():
            domain, codomain = self.compute_domain_codomain()

        h, w = self.weight.shape
        self.register_buffer('scale', torch.tensor(0.))
        self.register_buffer('u', normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain))
        self.register_buffer('v', normalize_v(self.weight.new_empty(w).normal_(0, 1), domain))

        # Try different random seeds to find the best u and v.
        with torch.no_grad():
            self.compute_weight(True, n_iterations=200, atol=None, rtol=None)
            best_scale = self.scale.clone()
            best_u, best_v = self.u.clone(), self.v.clone()
            if not (domain == 2 and codomain == 2):
                for _ in range(10):
                    self.u.copy_(normalize_u(self.weight.new_empty(h).normal_(0, 1), codomain))
                    self.v.copy_(normalize_v(self.weight.new_empty(w).normal_(0, 1), domain))
                    self.compute_weight(True, n_iterations=200)
                    if self.scale > best_scale:
                        # Track the running maximum; otherwise a later, worse
                        # restart that merely beats the first one would win.
                        best_scale = self.scale.clone()
                        best_u, best_v = self.u.clone(), self.v.clone()
            self.u.copy_(best_u)
            self.v.copy_(best_v)
            self.scale.copy_(best_scale)

    def reset_parameters(self, zero_init=False):
        """Initialise ``weight`` (Kaiming uniform) and ``bias`` (uniform in +-1/sqrt(fan_in))."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if zero_init:
            # normalize cannot handle zero weight in some cases.
            self.weight.data.div_(1000)
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def compute_domain_codomain(self):
        """Return the (domain, codomain) orders, squashing them when they are tensors."""
        if torch.is_tensor(self.domain):
            domain = asym_squash(self.domain)
            codomain = asym_squash(self.codomain)
        else:
            domain, codomain = self.domain, self.codomain
        return domain, codomain

    def compute_one_iter(self):
        """Run one iteration on detached ``weight``/``v`` and return ``u^T W v``.

        The stored buffers are not modified. Gradients can only flow through
        the (possibly tensor-valued) domain/codomain orders.
        """
        domain, codomain = self.compute_domain_codomain()
        v = self.v.detach()
        weight = self.weight.detach()
        u = normalize_u(torch.mv(weight, v), codomain)
        v = normalize_v(torch.mv(weight.t(), u), domain)
        return torch.dot(u, torch.mv(weight, v))

    def _resolve_iteration_settings(self, n_iterations, atol, rtol):
        """Fill unspecified settings from the layer defaults and validate them.

        Returns:
            Tuple ``(n_iterations, atol, rtol, max_itrs)``.

        Raises:
            ValueError: if neither ``n_iterations`` nor both tolerances are set.
        """
        n_iterations = self.n_iterations if n_iterations is None else n_iterations
        atol = self.atol if atol is None else atol
        # An explicitly passed rtol must be honoured (it used to be replaced by atol).
        rtol = self.rtol if rtol is None else rtol

        if n_iterations is None and (atol is None or rtol is None):
            raise ValueError('Need one of n_iteration or (atol, rtol).')

        max_itrs = 200 if n_iterations is None else n_iterations
        return n_iterations, atol, rtol, max_itrs

    def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
        """Return the softly normalised weight, optionally refreshing ``u`` and ``v``.

        Args:
            update: if True, iterate on ``u``/``v`` and store the result in
                the buffers before computing the norm estimate.
            n_iterations: overrides ``self.n_iterations`` when not None.
            atol: overrides ``self.atol`` when not None.
            rtol: overrides ``self.rtol`` when not None.

        Returns:
            ``weight / max(1, sigma / coeff)``; ``self.scale`` is set to sigma.

        Raises:
            ValueError: when updating without an iteration count or tolerances.
        """
        u = self.u
        v = self.v
        weight = self.weight

        if update:
            n_iterations, atol, rtol, max_itrs = self._resolve_iteration_settings(n_iterations, atol, rtol)
            # With no fixed count, both tolerances are guaranteed to be set.
            check_tol = n_iterations is None

            with torch.no_grad():
                domain, codomain = self.compute_domain_codomain()
                for _ in range(max_itrs):
                    # Algorithm from http://www.qetlab.com/InducedMatrixNorm.
                    if check_tol:
                        old_v = v.clone()
                        old_u = u.clone()

                    u = normalize_u(torch.mv(weight, v), codomain, out=u)
                    v = normalize_v(torch.mv(weight.t(), u), domain, out=v)

                    if check_tol:
                        err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
                        err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
                        tol_u = atol + rtol * torch.max(u)
                        tol_v = atol + rtol * torch.max(v)
                        if err_u < tol_u and err_v < tol_v:
                            break
                self.v.copy_(v)
                self.u.copy_(u)
                # Work on copies so later in-place buffer updates cannot
                # touch tensors used below.
                u = u.clone()
                v = v.clone()

        sigma = torch.dot(u, torch.mv(weight, v))
        with torch.no_grad():
            self.scale.copy_(sigma)
        # soft normalization: only when sigma larger than coeff
        factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
        weight = weight / factor
        return weight

    def forward(self, input):
        """Apply the linear map using the normalised weight (no ``u``/``v`` update)."""
        weight = self.compute_weight(update=False)
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        domain, codomain = self.compute_domain_codomain()
        return (
            'in_features={}, out_features={}, bias={}'
            ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format(
                self.in_features, self.out_features, self.bias is not None, self.coeff, domain, codomain,
                self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain)
            )
        )
147 |
148 |
class InducedNormConv2d(nn.Module):
    """2-D convolution whose kernel is softly rescaled by an induced-norm estimate.

    For 1x1 kernels the weight is treated as an (out_channels, in_channels)
    matrix. For larger kernels the convolution is treated as a linear map on
    the flattened input of shape (in_channels, H, W), where (H, W) is recorded
    in ``spatial_dims`` from the first input seen by ``forward``. The buffers
    ``u`` and ``v`` are created lazily on the first ``compute_weight`` call.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or pair.
        stride: int or pair.
        padding: int or pair.
        bias: whether to learn an additive bias.
        coeff: upper bound enforced on the estimated norm.
        domain: order of the input norm; a number or a tensor that is squashed
            with ``asym_squash`` before use.
        codomain: order of the output norm (same convention as ``domain``).
        n_iterations: fixed number of iterations per update, or None to
            iterate until the (atol, rtol) criterion is met.
        atol: absolute tolerance of the stopping criterion.
        rtol: relative tolerance of the stopping criterion.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, domain=2, codomain=2,
        n_iterations=None, atol=None, rtol=None, **unused_kwargs
    ):
        del unused_kwargs
        super(InducedNormConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.coeff = coeff
        self.n_iterations = n_iterations
        self.domain = domain
        self.codomain = codomain
        self.atol = atol
        self.rtol = rtol
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.register_buffer('initialized', torch.tensor(0))
        self.register_buffer('spatial_dims', torch.tensor([1., 1.]))
        self.register_buffer('scale', torch.tensor(0.))
        # Placeholders; resized and filled by _initialize_u_v.
        self.register_buffer('u', self.weight.new_empty(self.out_channels))
        self.register_buffer('v', self.weight.new_empty(self.in_channels))

    def compute_domain_codomain(self):
        """Return the (domain, codomain) orders, squashing them when they are tensors."""
        if torch.is_tensor(self.domain):
            domain = asym_squash(self.domain)
            codomain = asym_squash(self.codomain)
        else:
            domain, codomain = self.domain, self.codomain
        return domain, codomain

    def reset_parameters(self):
        """Initialise ``weight`` (Kaiming uniform) and ``bias`` (uniform in +-1/sqrt(fan_in))."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def _input_shape(self):
        """Return (channels, height, width) of the input the kxk estimate is defined on."""
        return self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())

    def _initialize_u_v(self):
        """Size and randomly initialise ``u``/``v``, then keep the best of several restarts."""
        with torch.no_grad():
            domain, codomain = self.compute_domain_codomain()
            if self.kernel_size == (1, 1):
                self.u.resize_(self.out_channels).normal_(0, 1)
                self.u.copy_(normalize_u(self.u, codomain))
                self.v.resize_(self.in_channels).normal_(0, 1)
                self.v.copy_(normalize_v(self.v, domain))
            else:
                c, h, w = self._input_shape()
                num_input_dim = c * h * w
                self.v.resize_(num_input_dim).normal_(0, 1)
                self.v.copy_(normalize_v(self.v, domain))
                # forward call to infer the shape
                u = F.conv2d(
                    self.v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None
                )
                num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3]
                # overwrite u with random init
                self.u.resize_(num_output_dim).normal_(0, 1)
                self.u.copy_(normalize_u(self.u, codomain))

            # Must be set before compute_weight is called below, otherwise
            # compute_weight would re-enter this method.
            self.initialized.fill_(1)

            # Try different random seeds to find the best u and v.
            self.compute_weight(True)
            best_scale = self.scale.clone()
            best_u, best_v = self.u.clone(), self.v.clone()
            if not (domain == 2 and codomain == 2):
                for _ in range(10):
                    if self.kernel_size == (1, 1):
                        self.u.copy_(normalize_u(self.weight.new_empty(self.out_channels).normal_(0, 1), codomain))
                        self.v.copy_(normalize_v(self.weight.new_empty(self.in_channels).normal_(0, 1), domain))
                    else:
                        self.u.copy_(normalize_u(torch.randn(num_output_dim).to(self.weight), codomain))
                        self.v.copy_(normalize_v(torch.randn(num_input_dim).to(self.weight), domain))
                    self.compute_weight(True, n_iterations=200)
                    if self.scale > best_scale:
                        # Track the running maximum; otherwise a later, worse
                        # restart that merely beats the first one would win.
                        best_scale = self.scale.clone()
                        best_u, best_v = self.u.clone(), self.v.clone()
            self.u.copy_(best_u)
            self.v.copy_(best_v)
            self.scale.copy_(best_scale)
            # These two lines are important, see https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm
            self.u = self.u.clone(memory_format=torch.contiguous_format)
            self.v = self.v.clone(memory_format=torch.contiguous_format)

    def compute_one_iter(self):
        """Run one iteration on detached ``weight``/``v`` and return ``u^T W v``.

        The stored buffers are not modified.

        Raises:
            ValueError: if ``u``/``v`` have not been initialised yet.
        """
        if not self.initialized:
            raise ValueError('Layer needs to be initialized first.')
        domain, codomain = self.compute_domain_codomain()
        v = self.v.detach()
        if self.kernel_size == (1, 1):
            weight = self.weight.detach().view(self.out_channels, self.in_channels)
            u = normalize_u(torch.mv(weight, v), codomain)
            v = normalize_v(torch.mv(weight.t(), u), domain)
            return torch.dot(u, torch.mv(weight, v))

        weight = self.weight.detach()
        c, h, w = self._input_shape()
        u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
        out_shape = u_s.shape
        u = normalize_u(u_s.view(-1), codomain)
        v_s = F.conv_transpose2d(
            u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0
        )
        v = normalize_v(v_s.view(-1), domain)
        weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
        return torch.dot(u.view(-1), weight_v.view(-1))

    def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
        """Return the softly normalised kernel, initialising ``u``/``v`` on first use.

        Args:
            update: if True, refresh ``u``/``v`` before estimating the norm.
            n_iterations: overrides ``self.n_iterations`` when not None.
            atol: overrides ``self.atol`` when not None.
            rtol: overrides ``self.rtol`` when not None.

        Raises:
            ValueError: if neither an iteration count nor both tolerances are set.
        """
        if not self.initialized:
            self._initialize_u_v()

        if self.kernel_size == (1, 1):
            return self._compute_weight_1x1(update, n_iterations, atol, rtol)
        else:
            return self._compute_weight_kxk(update, n_iterations, atol, rtol)

    def _resolve_iteration_settings(self, n_iterations, atol, rtol):
        """Fill unspecified settings from the layer defaults and validate them.

        Returns:
            Tuple ``(n_iterations, atol, rtol, max_itrs)``.

        Raises:
            ValueError: if neither ``n_iterations`` nor both tolerances are set.
        """
        n_iterations = self.n_iterations if n_iterations is None else n_iterations
        atol = self.atol if atol is None else atol
        # An explicitly passed rtol must be honoured (it used to be replaced by atol).
        rtol = self.rtol if rtol is None else rtol

        if n_iterations is None and (atol is None or rtol is None):
            raise ValueError('Need one of n_iteration or (atol, rtol).')

        max_itrs = 200 if n_iterations is None else n_iterations
        return n_iterations, atol, rtol, max_itrs

    def _soft_normalize(self, weight, sigma):
        """Store ``sigma`` in ``scale`` and divide ``weight`` by ``max(1, sigma / coeff)``."""
        with torch.no_grad():
            self.scale.copy_(sigma)
        # soft normalization: only when sigma larger than coeff
        factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
        return weight / factor

    def _compute_weight_1x1(self, update=True, n_iterations=None, atol=None, rtol=None):
        """Matrix version of the estimate, used when the kernel is 1x1."""
        n_iterations, atol, rtol, max_itrs = self._resolve_iteration_settings(n_iterations, atol, rtol)
        # With no fixed count, both tolerances are guaranteed to be set.
        check_tol = n_iterations is None

        u = self.u
        v = self.v
        weight = self.weight.view(self.out_channels, self.in_channels)
        if update:
            with torch.no_grad():
                domain, codomain = self.compute_domain_codomain()
                itrs_used = 0
                for _ in range(max_itrs):
                    if check_tol:
                        old_v = v.clone()
                        old_u = u.clone()

                    u = normalize_u(torch.mv(weight, v), codomain, out=u)
                    v = normalize_v(torch.mv(weight.t(), u), domain, out=v)

                    itrs_used = itrs_used + 1

                    if check_tol:
                        err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
                        err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
                        tol_u = atol + rtol * torch.max(u)
                        tol_v = atol + rtol * torch.max(v)
                        if err_u < tol_u and err_v < tol_v:
                            break
                if itrs_used > 0:
                    if domain != 1 and domain != 2:
                        self.v.copy_(v)
                    if codomain != 2 and codomain != float('inf'):
                        self.u.copy_(u)
                # These two lines are important, see https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm
                u = u.clone(memory_format=torch.contiguous_format)
                v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight, v))
        weight = self._soft_normalize(weight, sigma)
        return weight.view(self.out_channels, self.in_channels, 1, 1)

    def _compute_weight_kxk(self, update=True, n_iterations=None, atol=None, rtol=None):
        """Convolutional version of the estimate, used for kernels larger than 1x1."""
        n_iterations, atol, rtol, max_itrs = self._resolve_iteration_settings(n_iterations, atol, rtol)
        # With no fixed count, both tolerances are guaranteed to be set.
        check_tol = n_iterations is None

        u = self.u
        v = self.v
        weight = self.weight
        c, h, w = self._input_shape()
        if update:
            with torch.no_grad():
                domain, codomain = self.compute_domain_codomain()
                itrs_used = 0
                for _ in range(max_itrs):
                    if check_tol:
                        old_u = u.clone()
                        old_v = v.clone()

                    u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
                    out_shape = u_s.shape
                    u = normalize_u(u_s.view(-1), codomain, out=u)

                    v_s = F.conv_transpose2d(
                        u.view(out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0
                    )
                    v = normalize_v(v_s.view(-1), domain, out=v)

                    itrs_used = itrs_used + 1
                    if check_tol:
                        err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
                        err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
                        tol_u = atol + rtol * torch.max(u)
                        tol_v = atol + rtol * torch.max(v)
                        if err_u < tol_u and err_v < tol_v:
                            break
                if itrs_used > 0:
                    if domain != 2:
                        self.v.copy_(v)
                    if codomain != 2:
                        self.u.copy_(u)
                # These two lines are important, see https://pytorch.org/docs/master/_modules/torch/nn/utils/spectral_norm.html#spectral_norm
                v = v.clone(memory_format=torch.contiguous_format)
                u = u.clone(memory_format=torch.contiguous_format)

        weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
        weight_v = weight_v.view(-1)
        sigma = torch.dot(u.view(-1), weight_v)
        return self._soft_normalize(weight, sigma)

    def forward(self, input):
        """Convolve ``input`` with the normalised kernel (no ``u``/``v`` update)."""
        if not self.initialized:
            # Record the spatial size the kxk estimate will be defined on.
            self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims))
        weight = self.compute_weight(update=False)
        return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)

    def extra_repr(self):
        domain, codomain = self.compute_domain_codomain()
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.bias is None:
            s += ', bias=False'
        s += ', coeff={}, domain={:.2f}, codomain={:.2f}, n_iters={}, atol={}, rtol={}, learnable_ord={}'.format(
            self.coeff, domain, codomain, self.n_iterations, self.atol, self.rtol, torch.is_tensor(self.domain)
        )
        return s.format(**self.__dict__)
404 |
405 |
def projmax_(v):
    """Turn `v` into a one-hot vector, in place.

    The single 1 is placed at the entry with the largest absolute value;
    every other entry is zeroed. Returns the same tensor object.
    """
    peak = torch.abs(v).argmax()
    v.zero_()
    v[peak] = 1
    return v
412 |
413 |
def normalize_v(v, domain, out=None):
    """Normalise the input-side vector `v` for the given `domain` order.

    * `domain` is the plain number 2: L2-normalise (result written to `out`
      when given).
    * `domain == 1`: one-hot projection via `projmax_` (in place on `v`).
    * otherwise: keep the signs, raise the max-scaled magnitudes to the power
      `1 / (domain - 1)` and divide by their `domain`-norm. `out` is not used
      on this path or on the one-hot path.
    """
    plain_l2 = not torch.is_tensor(domain) and domain == 2
    if plain_l2:
        return F.normalize(v, p=2, dim=0, out=out)
    if domain == 1:
        return projmax_(v)

    magnitude = torch.abs(v)
    sign = v / magnitude
    # 0 / 0 yields NaN; treat the sign of a zero entry as +1.
    sign[torch.isnan(sign)] = 1
    magnitude = magnitude / torch.max(magnitude)
    magnitude = magnitude**(1 / (domain - 1))
    return sign * magnitude / vector_norm(magnitude, domain)
427 |
428 |
def normalize_u(u, codomain, out=None):
    """Normalise the output-side vector `u` for the given `codomain` order.

    * `codomain` is the plain number 2: L2-normalise (result written to `out`
      when given).
    * `codomain == inf`: one-hot projection via `projmax_` (in place on `u`).
    * otherwise: keep the signs, raise the max-scaled magnitudes to the power
      `codomain - 1` and divide by their norm of order
      `codomain / (codomain - 1)` (order `inf` when `codomain == 1`).
      `out` is not used on this path or on the one-hot path.
    """
    plain_l2 = not torch.is_tensor(codomain) and codomain == 2
    if plain_l2:
        return F.normalize(u, p=2, dim=0, out=out)
    if codomain == float('inf'):
        return projmax_(u)

    magnitude = torch.abs(u)
    sign = u / magnitude
    # 0 / 0 yields NaN; treat the sign of a zero entry as +1.
    sign[torch.isnan(sign)] = 1
    magnitude = magnitude / torch.max(magnitude)
    magnitude = magnitude**(codomain - 1)
    if codomain == 1:
        dual_order = float('inf')
    else:
        dual_order = codomain / (codomain - 1)
    return sign * magnitude / vector_norm(magnitude, dual_order)
445 |
446 |
def vector_norm(x, p):
    """Return `(sum(x ** p)) ** (1 / p)` over all entries of `x`.

    No absolute value is taken, so this equals the p-norm only for
    non-negative `x`.
    """
    flat = x.view(-1)
    powered = flat**p
    return torch.sum(powered)**(1 / p)
450 |
451 |
def leaky_elu(x, a=0.3):
    """Blend the identity and ELU: `a * x + (1 - a) * elu(x)`."""
    linear_part = a * x
    elu_part = (1 - a) * F.elu(x)
    return linear_part + elu_part
454 |
455 |
def asym_squash(x):
    """Squash `x` with `2 * tanh(-leaky_elu(c - x)) + 3`, where c = 0.5493061829986572."""
    shifted = -x + 0.5493061829986572
    squashed = torch.tanh(-leaky_elu(shifted))
    return squashed * 2 + 3
458 |
459 |
460 | # def asym_squash(x):
461 | # return torch.tanh(x) / 2. + 2.
462 |
463 |
464 | def _ntuple(n):
465 |
466 | def parse(x):
467 | if isinstance(x, container_abcs.Iterable):
468 | return x
469 | return tuple(repeat(x, n))
470 |
471 | return parse
472 |
473 |
# Fixed-width converters: a scalar becomes a tuple of that width, iterables pass through.
_single, _pair, _triple, _quadruple = (_ntuple(width) for width in (1, 2, 3, 4))
478 |
if __name__ == '__main__':
    # Smoke test: the norm order is a learnable parameter shared by the
    # domain and the codomain, and one iteration is differentiated w.r.t. it.
    order = nn.Parameter(torch.tensor(2.1))

    layer = InducedNormConv2d(10, 2, 3, 1, 1, atol=1e-3, rtol=1e-3, domain=order, codomain=order)
    # First call initialises u and v as a side effect.
    normalized = layer.compute_weight()

    estimate = layer.compute_one_iter()
    estimate.backward()
    print(order.grad)

    # layer.weight.data.copy_(normalized)
    # normalized = layer.compute_weight().cpu().detach().numpy()
    # import numpy as np
    # print(
    #     '{} {} {}'.format(
    #         np.linalg.norm(normalized, ord=2, axis=(0, 1)),
    #         '>' if np.linalg.norm(normalized, ord=2, axis=(0, 1)) > layer.scale else '<',
    #         layer.scale,
    #     )
    # )
499 |
--------------------------------------------------------------------------------
/lib/layers/base/utils.py:
--------------------------------------------------------------------------------
1 | from torch._six import container_abcs
2 | from itertools import repeat
3 |
4 |
5 | def _ntuple(n):
6 |
7 | def parse(x):
8 | if isinstance(x, container_abcs.Iterable):
9 | return x
10 | return tuple(repeat(x, n))
11 |
12 | return parse
13 |
14 |
# Fixed-width converters: a scalar becomes a tuple of that width, iterables pass through.
_single, _pair, _triple, _quadruple = (_ntuple(width) for width in (1, 2, 3, 4))
19 |
--------------------------------------------------------------------------------
/lib/layers/broyden.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as functional
4 | from torch.autograd import Function
5 | import numpy as np
6 | import pickle
7 | import sys
8 | import os
9 | from scipy.optimize import root
10 | import time
11 | from termcolor import colored
12 |
13 | import logging
14 |
15 | logger = logging.getLogger()
16 |
17 |
18 | def _safe_norm(v):
19 | if not torch.isfinite(v).all():
20 | return np.inf
21 | return torch.norm(v)
22 |
23 |
def scalar_search_armijo(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
    """Backtracking search for a step satisfying the Armijo condition.

    Args:
        phi: callable mapping a step size to the objective value.
        phi0: objective value at step 0.
        derphi0: directional derivative at step 0.
        c1: sufficient-decrease constant.
        alpha0: first step size tried.
        amin: the search gives up once the trial step is not above this value.

    Returns:
        `(alpha, phi(alpha), n_cubic)` on success, where `n_cubic` counts the
        cubic-interpolation steps taken, or `(None, last_phi, n_cubic)` when
        no acceptable step was found.
    """
    n_cubic = 0

    def sufficient_decrease(step, value):
        return value <= phi0 + c1 * step * derphi0

    # First do an update with the initial step size.
    prev_step = alpha0
    prev_val = phi(prev_step)
    if sufficient_decrease(prev_step, prev_val):
        return prev_step, prev_val, n_cubic

    # Otherwise, take the minimizer of a quadratic interpolant.
    cur_step = -(derphi0) * prev_step**2 / 2.0 / (prev_val - phi0 - derphi0 * prev_step)
    cur_val = phi(cur_step)

    # Then refine with cubic interpolation until the sufficient-decrease
    # condition holds or the step becomes too small. A positive step is
    # assumed to be a descent direction.
    while cur_step > amin:
        denom = prev_step**2 * cur_step**2 * (cur_step - prev_step)
        resid_cur = cur_val - phi0 - derphi0 * cur_step
        resid_prev = prev_val - phi0 - derphi0 * prev_step
        a = (prev_step**2 * resid_cur - cur_step**2 * resid_prev) / denom
        b = (-prev_step**3 * resid_cur + cur_step**3 * resid_prev) / denom

        next_step = (-b + torch.sqrt(torch.abs(b**2 - 3 * a * derphi0))) / (3.0 * a)
        next_val = phi(next_step)
        n_cubic += 1

        if sufficient_decrease(next_step, next_val):
            return next_step, next_val, n_cubic

        # Safeguard: fall back to bisection when the cubic step moves too
        # far from, or stays too close to, the current step.
        if (cur_step - next_step) > cur_step / 2.0 or (1 - next_step / cur_step) < 0.96:
            next_step = cur_step / 2.0

        prev_step, cur_step = cur_step, next_step
        prev_val, cur_val = cur_val, next_val

    # Failed to find a suitable step length
    return None, cur_val, n_cubic
64 |
65 |
def line_search(update, x0, g0, g, nstep=0, on=True):
    """Take a step from `x0` along `update`, optionally with an Armijo line search.

    Code adapted from scipy.

    Args:
        update: the proposed direction of update.
        x0: current iterate.
        g0: value of `g` at `x0`.
        g: residual function.
        nstep: unused; kept for interface compatibility.
        on: if True, search for a step size with `scalar_search_armijo` on
            `||g||^2`; otherwise (or if the search fails) use a step of 1.

    Returns:
        `(x_est, g(x_est), x_est - x0, g(x_est) - g0, ite)`, where `ite` is
        the iteration count reported by the search (0 without a search).
    """
    # One-slot cache of the most recently evaluated step, so the final
    # residual does not have to be recomputed when the search ends there.
    tmp_s = [0]
    tmp_g0 = [g0]
    tmp_phi = [torch.norm(g0)**2]

    def phi(s, store=True):
        if s == tmp_s[0]:
            return tmp_phi[0]  # If the step size is so small... just return something
        x_est = x0 + s * update
        g0_new = g(x_est)
        phi_new = _safe_norm(g0_new)**2
        if store:
            tmp_s[0] = s
            tmp_g0[0] = g0_new
            tmp_phi[0] = phi_new
        return phi_new

    if on:
        s, _, ite = scalar_search_armijo(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
    if (not on) or s is None:
        s = 1.0
        ite = 0

    x_est = x0 + s * update
    if s == tmp_s[0]:
        g0_new = tmp_g0[0]
    else:
        g0_new = g(x_est)
    return x_est, g0_new, x_est - x0, g0_new - g0, ite
100 |
def rmatvec(part_Us, part_VTs, x):
    """Batched left product `x^T (-I + U V^T)`.

    Shapes, as used by the einsum patterns below:
        x: (N, d)
        part_Us: (N, d, threshold)
        part_VTs: (N, threshold, d)

    Returns `-x` when `part_Us` has no elements.
    """
    if part_Us.nelement() == 0:
        return -x
    projected = torch.einsum('bi, bij -> bj', x, part_Us)  # (N, threshold)
    low_rank_term = torch.einsum('bj, bji -> bi', projected, part_VTs)  # (N, d)
    return -x + low_rank_term
110 |
111 |
def matvec(part_Us, part_VTs, x):
    """Batched right product `(-I + U V^T) x`.

    Shapes, as used by the einsum patterns below:
        x: (N, d)
        part_Us: (N, d, threshold)
        part_VTs: (N, threshold, d)

    Returns `-x` when `part_Us` has no elements.
    """
    if part_Us.nelement() == 0:
        return -x
    projected = torch.einsum('bji, bi -> bj', part_VTs, x)  # (N, threshold)
    low_rank_term = torch.einsum('bij, bj -> bi', part_Us, projected)  # (N, d)
    return -x + low_rank_term
121 |
122 |
def broyden(g_, x0, threshold, eps, ls=False, name="unknown"):
    """Search for a root of `g_` starting from `x0` with a limited-memory Broyden iteration.

    The approximate inverse Jacobian is kept in the low-rank form
    `-I + U V^T` used by `matvec`/`rmatvec`; one column of `Us` and one row of
    `VTs` is written per step.

    Args:
        g_: residual function; it is called on tensors shaped like `x0`.
        x0: initial guess; flattened to (batch, -1) internally.
        threshold: maximum number of Broyden steps (also the memory size).
        eps: per-element tolerance; it is multiplied by sqrt(numel(x0)) and
            compared against the norm of the whole residual.
        ls: whether `line_search` runs an Armijo search for the step size.
        name: not used in the body.

    Returns:
        A dict with the best iterate found ("result", reshaped like `x0`),
        step counters ("nstep", "tnstep", "lowest_step"), the residual norm of
        the best iterate ("diff", and per sample in "diff_detail"), the
        protective-break flag ("prot_break"), the residual-norm history
        ("trace"), and the scaled "eps" and "threshold".
    """
    # LBFGS_thres = min(threshold, 20)
    LBFGS_thres = threshold

    x0_shape = x0.shape
    x0 = x0.view(x0_shape[0], -1)

    bsz, total_hsize = x0.size()
    # Scale the tolerance so it applies to the norm over all elements.
    eps = eps * np.sqrt(np.prod(x0.shape))

    def g(x):
        # Adapter: evaluate g_ in the original shape, return it flattened.
        return g_(x.view(x0_shape)).view(bsz, -1)

    x_est = x0           # (bsz, d)
    gx = g(x_est)        # (bsz, d)
    nstep = 0
    tnstep = 0

    # For fast calculation of inv_jacobian (approximately)
    Us = torch.zeros(bsz, total_hsize, LBFGS_thres).to(x0)
    VTs = torch.zeros(bsz, LBFGS_thres, total_hsize).to(x0)
    # With no stored updates the approximate inverse Jacobian is -I.
    update = -gx
    new_objective = init_objective = torch.norm(gx).item()
    prot_break = False
    trace = [init_objective]

    # To be used in protective breaks
    protect_thres = 1e6
    lowest = new_objective
    lowest_xest, lowest_gx, lowest_step = x_est, gx, nstep
    while new_objective >= eps and nstep < threshold:
        x_est, gx, delta_x, delta_gx, ite = line_search(update, x_est, gx, g, nstep=nstep, on=ls)
        nstep += 1
        # Count line-search iterations as well as the Broyden step itself.
        tnstep += (ite+1)
        new_objective = torch.norm(gx).item()
        trace.append(new_objective)
        if new_objective < lowest:
            # Remember the best iterate; it is what gets returned.
            lowest_xest, lowest_gx = x_est.clone().detach(), gx.clone().detach()
            lowest = new_objective
            lowest_step = nstep
        if new_objective < eps:
            break
        if new_objective < 3*eps and nstep == threshold and np.max(trace[-threshold:]) / np.min(trace[-threshold:]) < 1.3:
            logger.info('Iterations exceeded 30 for broyden')
            # if there's hardly been any progress in the last 30 steps
            # NOTE(review): the window here is `threshold` steps, not a fixed 30 - confirm the intended wording.
            break
        if new_objective > init_objective * protect_thres:
            logger.info('Broyden failed')
            prot_break = True
            break

        # Rank-one update of the inverse-Jacobian approximation, using only
        # the columns/rows written in earlier steps.
        part_Us, part_VTs = Us[:,:,:(nstep-1) % LBFGS_thres], VTs[:,:(nstep-1) % LBFGS_thres]
        vT = rmatvec(part_Us, part_VTs, delta_x)  # (N, d)
        u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bi, bi -> b', vT, delta_gx)[:,None]
        # Replace NaNs (x != x) with zeros so they do not poison later steps.
        vT[vT != vT] = 0
        u[u != u] = 0
        VTs[:,(nstep-1) % LBFGS_thres] = vT
        Us[:,:,(nstep-1) % LBFGS_thres] = u
        update = -matvec(Us[:,:,:nstep], VTs[:,:nstep], gx)

    # Drop the references to the update buffers before returning.
    Us, VTs = None, None
    return {"result": lowest_xest.view(x0_shape),
            "nstep": nstep,
            "tnstep": tnstep,
            "lowest_step": lowest_step,
            "diff": torch.norm(lowest_gx).item(),
            "diff_detail": torch.norm(lowest_gx, dim=1),
            "prot_break": prot_break,
            "trace": trace,
            "eps": eps,
            "threshold": threshold}
194 |
195 |
def analyze_broyden(res_info, err=None, judge=True, name='forward', training=True, save_err=True):
    """Inspect the dict returned by `broyden` for signs of failure.

    For debugging use only :-)

    Args:
        res_info: result dict produced by `broyden`.
        err: object pickled to disk when a problem is found; required when
            `judge` is False.
        judge: if True, only return whether the run looks bad.
        name: 'forward' or anything else (treated as the backward pass);
            selects the file-name prefix and message colour.
        training: if False, file names get an 'eval_' prefix.
        save_err: whether to pickle `err` when a problem is found.

    Returns:
        In judge mode, a truthy value if the run hit the step limit, started
        with a bad residual, hit the protective break, or produced NaNs.
        Otherwise `(code, message, res_info)` with code 1 (NaN), 2 (bad
        start), 3 (protective break) or -1 (nothing found).
    """
    res_est = res_info['result']
    nstep = res_info['nstep']
    diff = res_info['diff']
    prot_break = res_info['prot_break']
    trace = res_info['trace']
    eps = res_info['eps']
    threshold = res_info['threshold']
    if judge:
        return nstep >= threshold or (nstep == 0 and (diff != diff or diff > eps)) or prot_break or torch.isnan(res_est).any()

    assert (err is not None), "Must provide err information when not in judgment mode"
    prefix, color = ('', 'red') if name == 'forward' else ('back_', 'blue')
    eval_prefix = '' if training else 'eval_'

    def _report(code, text, tag):
        # Print the warning and optionally dump `err`; the file is closed
        # deterministically instead of being left to the garbage collector.
        msg = colored(text, color)
        print(msg)
        if save_err:
            with open(f'{prefix}{eval_prefix}{tag}.pkl', 'wb') as fh:
                pickle.dump(err, fh)
        return (code, msg, res_info)

    # Case 1: A nan entry is produced in Broyden
    if torch.isnan(res_est).any():
        return _report(1, f"WARNING: nan found in Broyden's {name} result. Diff: {diff}", 'nan')

    # Case 2: Unknown problem with Broyden's method (probably due to nan update(s) to the weights)
    if nstep == 0 and (diff != diff or diff > eps):
        return _report(2, f"WARNING: Bad Broyden's method {name}. Why?? Diff: {diff}. STOP.", 'badbroyden')

    # Case 3: Protective break during Broyden (so that it does not diverge to infinity)
    if prot_break:
        return _report(
            3, f"WARNING: Hit Protective Break in {name}. Diff: {diff}. Total Iter: {len(trace)}", 'prot_break'
        )

    return (-1, '', res_info)
--------------------------------------------------------------------------------
/lib/layers/container.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class SequentialFlow(nn.Module):
    """Chain of flow layers, applied in order (and in reverse for `inverse`).

    When a log-density is passed in, it is threaded through every layer and
    returned alongside the transformed tensor.
    """

    def __init__(self, layersList):
        super(SequentialFlow, self).__init__()
        self.chain = nn.ModuleList(layersList)

    def forward(self, x, logpx=None, restore=False):
        """Apply every layer front to back, forwarding `restore` to each one."""
        with_density = logpx is not None
        for layer in self.chain:
            if with_density:
                x, logpx = layer(x, logpx, restore=restore)
            else:
                x = layer(x, restore=restore)
        return (x, logpx) if with_density else x

    def inverse(self, y, logpy=None):
        """Apply every layer's `inverse`, last layer first."""
        with_density = logpy is not None
        for idx in reversed(range(len(self.chain))):
            layer = self.chain[idx]
            if with_density:
                y, logpy = layer.inverse(y, logpy)
            else:
                y = layer.inverse(y)
        return (y, logpy) if with_density else y
31 |
32 |
class Inverse(nn.Module):
    """Wrap a flow so that its forward and inverse directions are swapped."""

    def __init__(self, flow):
        super(Inverse, self).__init__()
        self.flow = flow

    def forward(self, x, logpx=None):
        """Run the wrapped flow's `inverse`."""
        wrapped = self.flow
        return wrapped.inverse(x, logpx)

    def inverse(self, y, logpy=None):
        """Run the wrapped flow's `forward`."""
        wrapped = self.flow
        return wrapped.forward(y, logpy)
44 |
--------------------------------------------------------------------------------
/lib/layers/coupling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import mask_utils
4 |
5 | __all__ = ['CouplingBlock', 'ChannelCouplingBlock', 'MaskedCouplingBlock']
6 |
7 |
class CouplingBlock(nn.Module):
    """Basic coupling layer for Tensors of shape (n,d).

    The input is split in two halves along axis 1. With
    ``scale = sigmoid(s(x_a) + 2)``:

    Forward computation:
        y_a = x_a
        y_b = x_b * scale + t(x_a)
    Inverse computation:
        x_a = y_a
        x_b = (y_b - t(y_a)) / scale
    """

    def __init__(self, dim, nnet, swap=False):
        """
        Args:
            dim (int): total feature size; must be even.
            nnet (nn.Module): maps the conditioning half to a tensor whose
                first `dim // 2` columns are `s` and the remaining ones `t`.
            swap (bool): if True, the second half conditions the first.
        """
        super(CouplingBlock, self).__init__()
        assert (dim % 2 == 0)
        self.d = dim // 2
        self.nnet = nnet
        self.swap = swap

    def func_s_t(self, x):
        """Evaluate `nnet` and split its output into (s, t)."""
        out = self.nnet(x)
        return out[:, :self.d], out[:, self.d:]

    def _split(self, t):
        # Returns (conditioning half, transformed half).
        first, second = t[:, :self.d], t[:, self.d:]
        return (second, first) if self.swap else (first, second)

    def _merge(self, part_a, part_b):
        # Inverse of _split: put the halves back in their original order.
        ordered = [part_b, part_a] if self.swap else [part_a, part_b]
        return torch.cat(ordered, dim=1)

    def forward(self, x, logpx=None):
        """Forward computation of a simple coupling split on the axis=1.
        """
        x_a, x_b = self._split(x)
        y_a, y_b, logdetgrad = self._forward_computation(x_a, x_b)
        y = self._merge(y_a, y_b)
        if logpx is None:
            return y
        return y, logpx - logdetgrad.view(x.size(0), -1).sum(1, keepdim=True)

    def inverse(self, y, logpy=None):
        """Inverse computation of a simple coupling split on the axis=1.
        """
        y_a, y_b = self._split(y)
        x_a, x_b, logdetgrad = self._inverse_computation(y_a, y_b)
        x = self._merge(x_a, x_b)
        if logpy is None:
            return x
        return x, logpy + logdetgrad

    def _forward_computation(self, x_a, x_b):
        s_a, t_a = self.func_s_t(x_a)
        scale = torch.sigmoid(s_a + 2.)
        y_b = x_b * scale + t_a
        return x_a, y_b, self._logdetgrad(scale)

    def _inverse_computation(self, y_a, y_b):
        s_a, t_a = self.func_s_t(y_a)
        scale = torch.sigmoid(s_a + 2.)
        x_b = (y_b - t_a) / scale
        return y_a, x_b, self._logdetgrad(scale)

    def _logdetgrad(self, scale):
        """
        Returns:
            Tensor (N, 1): containing ln |det J| where J is the jacobian
        """
        per_sample = torch.log(scale).view(scale.shape[0], -1)
        return per_sample.sum(1, keepdim=True)

    def extra_repr(self):
        return 'dim={d}, swap={swap}'.format(**self.__dict__)
87 |
88 |
class ChannelCouplingBlock(CouplingBlock):
    """Channel-wise coupling layer for images.

    `mask_type` selects which half of the channels is the conditioning half:
    'channel0' keeps the default order, 'channel1' swaps the halves.
    """

    def __init__(self, dim, nnet, mask_type='channel0'):
        if mask_type not in ('channel0', 'channel1'):
            raise ValueError('Unknown mask type.')
        swap = mask_type == 'channel1'
        super(ChannelCouplingBlock, self).__init__(dim, nnet, swap)
        self.mask_type = mask_type

    def extra_repr(self):
        return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__)
105 |
106 |
class MaskedCouplingBlock(nn.Module):
    """Coupling layer for images implemented using masks.

    A binary mask `b` obtained from `mask_utils.get_mask` marks the entries
    that pass through unchanged and condition the rest:

        y = b * x + (1 - b) * (x * s + t),   s = sigmoid(nnet(b * x)[:, :dim] + 2)
    """

    def __init__(self, dim, nnet, mask_type='checkerboard0'):
        nn.Module.__init__(self)
        self.d = dim
        self.nnet = nnet
        self.mask_type = mask_type

    def func_s_t(self, x):
        """Evaluate `nnet` and return (scale, shift); the scale is already squashed by a sigmoid."""
        out = self.nnet(x)
        scale = torch.sigmoid(out[:, :self.d] + 2.)
        shift = out[:, self.d:]
        return scale, shift

    def forward(self, x, logpx=None):
        mask = mask_utils.get_mask(x, mask_type=self.mask_type)

        # masked forward
        kept = mask * x
        s, t = self.func_s_t(kept)
        y = (x * s + t) * (1 - mask) + kept

        if logpx is None:
            return y
        return y, logpx - self._logdetgrad(s, mask)

    def inverse(self, y, logpy=None):
        mask = mask_utils.get_mask(y, mask_type=self.mask_type)

        # masked inverse
        kept = mask * y
        s, t = self.func_s_t(kept)
        x = kept + (1 - mask) * (y - t) / s

        if logpy is None:
            return x
        return x, logpy + self._logdetgrad(s, mask)

    def _logdetgrad(self, s, mask):
        """Sum log(s) over the transformed (unmasked) entries of each sample; shape (N, 1)."""
        log_scale = torch.log(s)
        log_scale.mul_(1 - mask)
        return log_scale.view(s.shape[0], -1).sum(1, keepdim=True)

    def extra_repr(self):
        return 'dim={d}, mask_type={mask_type}'.format(**self.__dict__)
156 |
--------------------------------------------------------------------------------
/lib/layers/elemwise.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | _DEFAULT_ALPHA = 1e-6
6 |
7 |
class ZeroMeanTransform(nn.Module):
    """Shift the input by -0.5; a pure translation, so log-densities pass through unchanged."""

    def __init__(self):
        super(ZeroMeanTransform, self).__init__()

    def forward(self, x, logpx=None):
        """Return ``x - 0.5`` (and ``logpx`` untouched when it is given)."""
        shifted = x - .5
        return shifted if logpx is None else (shifted, logpx)

    def inverse(self, y, logpy=None):
        """Return ``y + 0.5`` (and ``logpy`` untouched when it is given)."""
        restored = y + .5
        return restored if logpy is None else (restored, logpy)
24 |
25 |
class Normalize(nn.Module):
    """Per-channel affine normalisation ``(x - mean) / std`` of the first ``len(mean)`` channels.

    ``mean`` and ``std`` are stored as float32 buffers. The indexing with
    ``[None, :, None, None]`` means inputs are expected to be 4-D.
    """

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        for buf_name, value in (('mean', mean), ('std', std)):
            self.register_buffer(buf_name, torch.as_tensor(value, dtype=torch.float32))

    def forward(self, x, logpx=None):
        """Normalise a copy of ``x``; the input tensor itself is not modified."""
        n_ch = len(self.mean)
        y = x.clone()
        head = y[:, :n_ch]  # view into y, so the in-place ops below update y
        head.sub_(self.mean[None, :, None, None])
        head.div_(self.std[None, :, None, None])
        if logpx is not None:
            return y, logpx - self._logdetgrad(x)
        return y

    def inverse(self, y, logpy=None):
        """Undo ``forward`` on a copy of ``y``."""
        n_ch = len(self.mean)
        x = y.clone()
        head = x[:, :n_ch]
        head.mul_(self.std[None, :, None, None])
        head.add_(self.mean[None, :, None, None])
        if logpy is not None:
            return x, logpy + self._logdetgrad(x)
        return x

    def _logdetgrad(self, x):
        """Return ``-sum(log|std|)`` over channels and spatial positions, shape (N, 1).

        ``x`` is used only for its batch and spatial sizes.
        """
        n_batch = x.shape[0]
        neg_log_std = -torch.log(torch.abs(self.std))
        per_position = neg_log_std.view(1, -1, 1, 1).expand(
            n_batch, len(self.std), x.shape[2], x.shape[3]
        )
        return per_position.reshape(n_batch, -1).sum(-1, keepdim=True)
56 |
57 |
class LogitTransform(nn.Module):
    """
    The preprocessing step used in Real NVP, as implemented below:
        forward:  y = logit(a + (1 - 2a) * x)
        inverse:  x = (sigmoid(y) - a) / (1 - 2a)
    where ``a`` is ``alpha``.
    """

    def __init__(self, alpha=_DEFAULT_ALPHA):
        super(LogitTransform, self).__init__()
        self.alpha = alpha

    def forward(self, x, logpx=None, restore=False):
        """Map ``x`` to logit space. ``restore`` is accepted but not used here."""
        squeezed = self.alpha + (1 - 2 * self.alpha) * x
        y = torch.log(squeezed) - torch.log(1 - squeezed)
        if logpx is None:
            return y
        logdet = self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True)
        return y, logpx - logdet

    def inverse(self, y, logpy=None):
        """Map logits ``y`` back to the original data range."""
        x = (torch.sigmoid(y) - self.alpha) / (1 - 2 * self.alpha)
        if logpy is None:
            return x
        logdet = self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True)
        return x, logpy + logdet

    def _logdetgrad(self, x):
        """Element-wise log-derivative of the forward map, evaluated at ``x``."""
        squeezed = self.alpha + (1 - 2 * self.alpha) * x
        # d/dx logit(s) = (1 - 2a) / (s * (1 - s)), and s - s*s == s * (1 - s).
        return -torch.log(squeezed - squeezed * squeezed) + math.log(1 - 2 * self.alpha)

    def __repr__(self):
        return '{name}({alpha})'.format(name=self.__class__.__name__, alpha=self.alpha)
89 |
--------------------------------------------------------------------------------
/lib/layers/glow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class InvertibleLinear(nn.Module):
    """Learnable invertible linear map ``y = x W^T`` with a log-determinant term.

    The weight is initialised to a random permutation matrix, so the layer
    starts with ``log|det W| = 0``.
    """

    def __init__(self, dim):
        """
        Args:
            dim: size of the (square) weight matrix.
        """
        super(InvertibleLinear, self).__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)])

    def forward(self, x, logpx=None):
        """Apply the linear map; also return ``logpx - log|det W|`` if ``logpx`` is given."""
        y = F.linear(x, self.weight)
        if logpx is None:
            return y
        else:
            return y, logpx - self._logdetgrad

    def inverse(self, y, logpy=None):
        """Apply the inverse map; also return ``logpy + log|det W|`` if ``logpy`` is given."""
        x = F.linear(y, self.weight.inverse())
        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad

    @property
    def _logdetgrad(self):
        """0-d tensor holding ``log|det W|``.

        Uses ``torch.slogdet`` rather than ``log(abs(det(W)))``: the plain
        determinant underflows to 0 (or overflows) in float32 for badly
        scaled weights, which would turn the log-density into +/-inf.
        """
        return torch.slogdet(self.weight)[1]

    def extra_repr(self):
        return 'dim={}'.format(self.dim)
33 |
34 |
class InvertibleConv2d(nn.Module):
    """Learnable invertible 1x1 convolution with a log-determinant term.

    The ``(dim, dim)`` weight is initialised to a random permutation matrix
    and applied at every spatial position, so the total log-determinant is
    ``H * W * log|det W|``.
    """

    def __init__(self, dim):
        """
        Args:
            dim: number of channels (input and output).
        """
        super(InvertibleConv2d, self).__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)])

    def forward(self, x, logpx=None):
        """Apply the 1x1 convolution; also update ``logpx`` if it is given."""
        y = F.conv2d(x, self.weight.view(self.dim, self.dim, 1, 1))
        if logpx is None:
            return y
        else:
            # One copy of log|det W| per spatial position.
            return y, logpx - self._logdetgrad.expand_as(logpx) * x.shape[2] * x.shape[3]

    def inverse(self, y, logpy=None):
        """Apply the inverse 1x1 convolution; also update ``logpy`` if it is given."""
        x = F.conv2d(y, self.weight.inverse().view(self.dim, self.dim, 1, 1))
        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad.expand_as(logpy) * x.shape[2] * x.shape[3]

    @property
    def _logdetgrad(self):
        """0-d tensor holding ``log|det W|``.

        Uses ``torch.slogdet`` rather than ``log(abs(det(W)))``: the plain
        determinant underflows to 0 (or overflows) in float32 for badly
        scaled weights, which would turn the log-density into +/-inf.
        """
        return torch.slogdet(self.weight)[1]

    def extra_repr(self):
        return 'dim={}'.format(self.dim)
62 |
--------------------------------------------------------------------------------
/lib/layers/implicit_block.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Function
6 | from .broyden import broyden
7 | import copy
8 | import lib.layers.base as base_layers
9 |
10 | import logging
11 |
12 | logger = logging.getLogger()
13 |
14 | __all__ = ['imBlock']
15 |
16 |
def find_fixed_point(g, y, threshold=1000, eps=1e-5):
    """Iterate ``x <- g(x)`` starting from ``g(y)`` until successive iterates agree.

    Args:
        g: callable mapping a tensor to a tensor of the same shape.
        y: starting point; also sets the per-element tolerance
            ``eps + eps * |y|``.
        threshold: iteration cap. Once the counter exceeds it, the current
            iterate is returned and the event is logged.
        eps: absolute and relative tolerance factor.

    Returns:
        The last iterate ``x``.
    """
    x, x_prev = g(y), y
    tol = eps + eps * y.abs()
    n_iter = 0
    # Converged when every element satisfies (x - x_prev)^2 < tol.
    while not torch.all((x - x_prev)**2 / tol < 1.):
        x, x_prev = g(x), x
        n_iter += 1
        if n_iter > threshold:
            logger.info(torch.abs(x - x_prev).max())
            # Report the cap actually in use instead of a hard-coded 1000.
            logger.info('Iterations exceeded %d for fixed point.', threshold)
            break
    return x
29 |
def primes(n):
    """Return the distinct prime factors of ``n`` in ascending order.

    Repeated factors are reported once (e.g. ``primes(12) == [2, 3]``).
    ``n <= 1`` yields an empty list.
    """
    factors = set()
    d = 2
    # Trial division: any composite n has a factor no larger than sqrt(n).
    while d * d <= n:
        while n % d == 0:
            factors.add(d)
            n //= d
        d += 1
    if n > 1:
        # Whatever is left after dividing out the small factors is itself prime.
        factors.add(n)
    # Sort so the result does not depend on set iteration order.
    return sorted(factors)
41 |
def choose_prime(n):
    """Look up the prime associated with a supported dimensionality ``n``.

    Only ``n == 32 * 32 * 3`` (returns 3079) and ``n == 2`` (returns 2) are
    known.

    Raises:
        AssertionError: for any other ``n``. Raised explicitly so the check
            survives ``python -O``, where ``assert`` statements are removed
            and the function would otherwise silently return ``None``.
    """
    if n == 32 * 32 * 3:
        return 3079
    if n == 2:
        return 2
    raise AssertionError('Please specify the prime or the power of a prime given {}'.format(n))
49 |
50 |
class RootFind(Function):
    """Autograd Function that solves ``nnet_z(z) + z = nnet_x(x) + x`` for ``z``.

    The solve runs under ``torch.no_grad()`` and the result is returned
    detached; ``backward`` is intentionally unusable.
    The last two entries of ``*args`` are always ``(eps, threshold)``.
    """

    @staticmethod
    def f(nnet_z, nnet_x, z, x):
        """Return ``nnet_x(x) - nnet_z(z)``."""
        return nnet_x(x) - nnet_z(z)

    @staticmethod
    def banach_find_root(nnet_z, nnet_x, z0, x, *args):
        """Solve by fixed-point iteration of ``z <- nnet_x(x) + x - nnet_z(z)`` from ``z0``."""
        eps, threshold = args[-2], args[-1]  # Can also set this to be different, based on training/inference
        target = nnet_x(x) + x

        def contraction(z):
            return target - nnet_z(z)

        z_est = find_fixed_point(contraction, z0, threshold=threshold, eps=eps)
        if threshold > 100:
            torch.cuda.empty_cache()
        return z_est.clone().detach()

    @staticmethod
    def broyden_find_root(nnet_z, nnet_x, z0, x, *args):
        """Solve with ``broyden`` from a zero start; fall back to fixed-point iteration.

        The fallback is taken when ``broyden`` reports ``'prot_break'`` and
        runs ``banach_find_root`` from ``z0`` with an iteration cap of 1000.
        """
        eps, threshold = args[-2], args[-1]  # Can also set this to be different, based on training/inference
        target = nnet_x(x) + x

        def residual(z):
            return target - nnet_z(z) - z

        info = broyden(residual, torch.zeros_like(z0).to(z0), threshold=threshold, eps=eps, name="forward")
        if info['prot_break']:
            z_est = RootFind.banach_find_root(nnet_z, nnet_x, z0, x, eps, 1000)
        else:
            z_est = info['result']
        if threshold > 100:
            torch.cuda.empty_cache()
        return z_est.clone().detach()

    @staticmethod
    def forward(ctx, nnet_z, nnet_x, z0, x, method, *args):
        """Run the solver named by ``method`` (``'broyden'``, otherwise Banach) without grad."""
        use_broyden = method == 'broyden'
        root_find = RootFind.broyden_find_root if use_broyden else RootFind.banach_find_root
        ctx.args_len = len(args)
        with torch.no_grad():
            z_est = root_find(nnet_z, nnet_x, z0, x, *args)

        # If one would like to analyze the convergence process (e.g., failures, stability), should
        # insert here or in broyden_find_root.
        return z_est

    @staticmethod
    def backward(ctx, grad_z):
        """Not supported: asserts immediately."""
        assert 0, 'Cannot backward to this function.'
        grad_args = [None] * ctx.args_len
        return (None, None, grad_z, None, *grad_args)
101 |
102 |
class imBlock(nn.Module):
    """Implicit flow block whose output ``z`` solves ``nnet_z(z) + z = nnet_x(x) + x``.

    ``forward`` finds ``z`` with ``RootFind`` and routes gradients through the
    inner ``Backward`` Function (implicit differentiation). ``inverse`` solves
    the same equation for ``x`` given ``z`` by swapping the two networks.
    The log-determinant is the difference of the log-determinants of
    ``I + d nnet_x / dx`` and ``I + d nnet_z / dz``.
    """

    def __init__(
        self,
        nnet_x,
        nnet_z,
        geom_p=0.5,
        lamb=2.,
        n_power_series=None,
        exact_trace=False,
        brute_force=False,
        n_samples=1,
        n_exact_terms=2,
        n_exact_terms_test=20,
        n_dist='geometric',
        neumann_grad=True,
        grad_in_forward=True,
        eps_forward=1e-6,
        eps_backward=1e-10,
        eps_sample=1e-5,
        threshold=30,
    ):
        """
        Args:
            nnet_x: a nn.Module applied to the input side ``x``.
            nnet_z: a nn.Module applied to the output side ``z``.
            geom_p: success probability of the geometric distribution; stored as its logit.
            lamb: rate of the Poisson distribution.
            n_power_series: number of power series. If not None, uses a biased approximation to logdet.
            exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian.
            brute_force: Computes the exact logdet. Only available for 2D inputs.
            n_samples: number of draws of the random truncation length.
            n_exact_terms: number of leading series terms that get coefficient 1 in training mode.
            n_exact_terms_test: same as ``n_exact_terms`` but used when not training.
            n_dist: 'geometric' or 'poisson'; distribution of the random truncation length.
            neumann_grad: in training mode, use ``neumann_logdet_estimator`` instead of
                ``basic_logdet_estimator``.
            grad_in_forward: in training mode, compute the estimator through ``mem_eff_wrapper``.
            eps_forward: tolerance handed to the root solver in ``forward``.
            eps_backward: tolerance handed to the solver in ``Backward.backward``.
            eps_sample: tolerance handed to the root solver in ``inverse``.
            threshold: iteration cap handed to all of the solvers above.
        """
        super(imBlock, self).__init__()

        self.nnet_x = nnet_x
        self.nnet_z = nnet_z
        # Frozen deep copies; forward() refreshes them from the live networks and
        # hands them to Backward.apply.
        self.nnet_x_copy = copy.deepcopy(self.nnet_x)
        self.nnet_z_copy = copy.deepcopy(self.nnet_z)
        for params in self.nnet_x_copy.parameters():
            params.requires_grad_(False)
        for params in self.nnet_z_copy.parameters():
            params.requires_grad_(False)

        self.n_dist = n_dist
        # geom_p is kept as a logit and mapped back with sigmoid in _logdetgrad.
        # NOTE(review): np.log yields float64, so the trailing .float() presumably returns a
        # plain tensor rather than the nn.Parameter, leaving geom_p unregistered - confirm.
        self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1. - geom_p))).float()
        self.lamb = nn.Parameter(torch.tensor(lamb)).float()
        self.n_samples = n_samples
        self.n_power_series = n_power_series
        self.exact_trace = exact_trace
        self.brute_force = brute_force
        self.n_exact_terms = n_exact_terms
        self.n_exact_terms_test = n_exact_terms_test
        self.grad_in_forward = grad_in_forward
        self.neumann_grad = neumann_grad
        self.eps_forward = eps_forward
        self.eps_backward = eps_backward
        self.eps_sample = eps_sample
        self.threshold = threshold

        # store the samples of n.
        self.register_buffer('last_n_samples', torch.zeros(self.n_samples))
        # first and second moments of the most recent training-time logdet estimate.
        self.register_buffer('last_firmom', torch.zeros(1))
        self.register_buffer('last_secmom', torch.zeros(1))


    class Backward(Function):
        """
        A 'dummy' function that does nothing in the forward pass and perform implicit differentiation
        in the backward pass. Essentially a wrapper that provides backprop for the `imBlock` class.
        You should use this inner class in imBlock's forward() function by calling:

            self.Backward.apply(self.func, ...)

        """
        @staticmethod
        def forward(ctx, nnet_z, nnet_x, z, x, *args):
            # Identity on z; only records what backward() needs.
            ctx.save_for_backward(z, x)
            ctx.nnet_z = nnet_z
            ctx.nnet_x = nnet_x
            ctx.args = args
            return z

        @staticmethod
        def backward(ctx, grad):
            torch.cuda.empty_cache()

            grad = grad.clone()
            z, x = ctx.saved_tensors
            args = ctx.args
            # The last two extra arguments are always (eps, threshold).
            eps, threshold = args[-2:]

            nnet_z = ctx.nnet_z
            nnet_x = ctx.nnet_x
            # Fresh leaf tensors so that .grad can be read off after each backward call.
            z = z.clone().detach().requires_grad_()
            x = x.clone().detach().requires_grad_()

            with torch.enable_grad():
                Fz = nnet_z(z) + z

            def g(x_):
                # Vector-Jacobian product of x_ with dFz/dz, minus the incoming gradient.
                Fz.backward(x_, retain_graph=True)  # Retain for future calls to g
                xJ = z.grad.clone().detach()
                z.grad.zero_()
                return xJ - grad

            # Solve g(dl_dh) = 0 for dl_dh, starting from zero.
            dl_dh = torch.zeros_like(grad).to(grad)
            result_info = broyden(g, dl_dh, threshold=threshold, eps=eps, name="backward")
            dl_dh = result_info['result']
            # Final backward call without retain_graph so the graph of Fz can be released.
            Fz.backward(torch.zeros_like(dl_dh), retain_graph=False)

            with torch.enable_grad():
                Fx = nnet_x(x) + x
            # Push the solved gradient through Fx to obtain the gradient w.r.t. x.
            Fx.backward(dl_dh)
            dl_dx = x.grad.clone().detach()
            x.grad.zero_()

            # No gradients for the two networks or for the trailing extra arguments.
            grad_args = [None for _ in range(len(args))]
            return (None, None, dl_dh, dl_dx, *grad_args)


    def forward(self, x, logpx=None, restore=False):
        """Solve for ``z`` given ``x``; also return ``logpx - logdet`` if ``logpx`` is given.

        Args:
            x: input tensor.
            logpx: optional log-density to update.
            restore: if True, first run the frozen network copies once on the detached
                input under ``no_grad`` and discard the outputs.
                NOTE(review): presumably this refreshes state held inside the copies - confirm.
        """
        z0 = x.clone().detach()
        if restore:
            with torch.no_grad():
                _ = self.nnet_x_copy(z0)
                _ = self.nnet_z_copy(z0)
        # Gradient-free root solve, started from the input itself.
        z = RootFind.apply(self.nnet_z, self.nnet_x, z0, z0, 'broyden', self.eps_forward, self.threshold)
        z = RootFind.f(self.nnet_z, self.nnet_x, z.detach(), z0) + z0 # For backwarding to parameters in func
        # Sync the frozen copies with the live weights before using them in Backward.
        self.nnet_x_copy.load_state_dict(self.nnet_x.state_dict())
        self.nnet_z_copy.load_state_dict(self.nnet_z.state_dict())
        z = self.Backward.apply(self.nnet_z_copy, self.nnet_x_copy, z, x, 'broyden', self.eps_backward, self.threshold)
        if logpx is None:
            return z
        else:
            return z, logpx - self._logdetgrad(z, x)

    def inverse(self, z, logpy=None):
        """Solve for ``x`` given ``z``; also return ``logpy + logdet`` if ``logpy`` is given.

        The two networks are passed to ``RootFind`` in swapped order, so the same
        equation is solved for the other variable.
        """
        x0 = z.clone().detach()
        x = RootFind.apply(self.nnet_x, self.nnet_z, x0, z, 'broyden', self.eps_sample, self.threshold)
        # x = RootFind.apply(self.nnet_x, self.nnet_z, x0, z, 'banach', self.eps_sample, self.threshold)
        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(z, x)

    def _logdetgrad(self, z, x):
        """Returns logdet|dz/dx|."""

        with torch.enable_grad():
            # Exact path: taken for 2-D inputs with at most 10 features when brute_force
            # is set or the module is not in training mode.
            if (self.brute_force or not self.training) and (x.ndimension() == 2 and x.shape[1] <= 10):
                x = x.requires_grad_(True)
                z = z.requires_grad_(True)
                Fx = x + self.nnet_x(x)
                Jx = batch_jacobian(Fx, x)
                logdet_x = torch.logdet(Jx)

                Fz = z + self.nnet_z(z)
                Jz = batch_jacobian(Fz, z)
                logdet_z = torch.logdet(Jz)

                return (logdet_x - logdet_z).view(-1, 1)
            # Sampler and tail probability of the random truncation length.
            # NOTE(review): any other n_dist value leaves sample_fn/rcdf_fn undefined.
            if self.n_dist == 'geometric':
                geom_p = torch.sigmoid(self.geom_p).item()
                sample_fn = lambda m: geometric_sample(geom_p, m)
                rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset)
            elif self.n_dist == 'poisson':
                lamb = self.lamb.item()
                sample_fn = lambda m: poisson_sample(lamb, m)
                rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)

            if self.training:
                if self.n_power_series is None:
                    # Unbiased estimation.
                    lamb = self.lamb.item()
                    n_samples = sample_fn(self.n_samples)
                    n_power_series = max(n_samples) + self.n_exact_terms
                    # Term k is reweighted by the fraction of draws that reach it divided by
                    # the tail probability; this equals 1 for k <= n_exact_terms.
                    coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \
                        sum(n_samples >= k - self.n_exact_terms) / len(n_samples)
                else:
                    # Truncated estimation.
                    n_power_series = self.n_power_series
                    coeff_fn = lambda k: 1.
            else:
                # Unbiased estimation with more exact terms.

                lamb = self.lamb.item()
                n_samples = sample_fn(self.n_samples)
                n_power_series = max(n_samples) + self.n_exact_terms_test
                coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms_test) * \
                    sum(n_samples >= k - self.n_exact_terms_test) / len(n_samples)

            if not self.exact_trace:
                ####################################
                # Power series with trace estimator.
                ####################################
                # vareps_x = torch.randn_like(x)
                # vareps_z = torch.randn_like(z)
                # Probe vectors with entries in {-1, +1}, drawn independently for x and z.
                vareps_x = torch.distributions.bernoulli.Bernoulli(torch.Tensor([0.5])).sample(x.shape).reshape(x.shape).to(x) * 2 - 1
                vareps_z = torch.distributions.bernoulli.Bernoulli(torch.Tensor([0.5])).sample(z.shape).reshape(z.shape).to(z) * 2 - 1

                # Choose the type of estimator.
                if self.training and self.neumann_grad:
                    estimator_fn = neumann_logdet_estimator
                else:
                    estimator_fn = basic_logdet_estimator

                # Do backprop-in-forward to save memory.
                if self.training and self.grad_in_forward:
                    logdet_x = mem_eff_wrapper(
                        estimator_fn, self.nnet_x, x, n_power_series, vareps_x, coeff_fn, self.training
                    )
                    logdet_z = mem_eff_wrapper(
                        estimator_fn, self.nnet_z, z, n_power_series, vareps_z, coeff_fn, self.training
                    )
                    logdetgrad = logdet_x - logdet_z
                else:
                    x = x.requires_grad_(True)
                    z = z.requires_grad_(True)
                    Fx = self.nnet_x(x)
                    Fz = self.nnet_z(z)
                    logdet_x = estimator_fn(Fx, x, n_power_series, vareps_x, coeff_fn, self.training)
                    logdet_z = estimator_fn(Fz, z, n_power_series, vareps_z, coeff_fn, self.training)
                    logdetgrad = logdet_x - logdet_z
            else:
                ############################################
                # Power series with exact trace computation.
                ############################################
                x = x.requires_grad_(True)
                z = z.requires_grad_(True)
                Fx = self.nnet_x(x)
                Jx = batch_jacobian(Fx, x)
                # k = 1 term of the series; higher powers are accumulated in the loop.
                logdetJx = batch_trace(Jx)
                Jx_k = Jx
                for k in range(2, n_power_series + 1):
                    Jx_k = torch.bmm(Jx, Jx_k)
                    logdetJx = logdetJx + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(Jx_k)
                Fz = self.nnet_z(z)
                Jz = batch_jacobian(Fz, z)
                logdetJz = batch_trace(Jz)
                Jz_k = Jz
                for k in range(2, n_power_series + 1):
                    Jz_k = torch.bmm(Jz, Jz_k)
                    logdetJz = logdetJz + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(Jz_k)
                logdetgrad = logdetJx - logdetJz

            # Record the drawn truncation lengths and estimator moments in the buffers.
            if self.training and self.n_power_series is None:
                self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples))
                estimator = logdetgrad.detach()
                self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom))
                self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom))
            return logdetgrad.view(-1, 1)

    def extra_repr(self):
        # NOTE(review): neumann_grad appears twice in the format string and argument list.
        return 'dist={}, n_samples={}, n_power_series={}, neumann_grad={}, exact_trace={}, brute_force={}, neumann_grad={}, grad_in_forward={}'.format(
            self.n_dist, self.n_samples, self.n_power_series, self.neumann_grad, self.exact_trace, self.brute_force, self.neumann_grad, self.grad_in_forward
        )
356 |
357 |
def batch_jacobian(g, x, create_graph=True):
    """Build the per-sample Jacobian of ``g`` w.r.t. ``x``, one output column at a time.

    Args:
        g: tensor of shape (N, D_out) computed from ``x``.
        x: tensor of shape (N, D_in) that requires grad.
        create_graph: forwarded to ``torch.autograd.grad``.

    Returns:
        Tensor of shape (N, D_out, D_in).
    """
    rows = [
        torch.autograd.grad(torch.sum(g[:, col]), x, create_graph=create_graph)[0].view(x.shape[0], 1, x.shape[1])
        for col in range(g.shape[1])
    ]
    return torch.cat(rows, 1)
363 |
364 |
def batch_trace(M):
    """Return the trace of each matrix in a batch of shape (N, D, D) as a length-N tensor."""
    n_batch, n_rows = M.shape[0], M.shape[1]
    flat = M.view(n_batch, -1)
    # In the row-major flattening, diagonal entries sit n_rows + 1 apart.
    diagonal = flat[:, ::n_rows + 1]
    return diagonal.sum(1)
367 |
368 |
369 |
370 | #####################
371 | # Logdet Estimators
372 | #####################
class MemoryEfficientLogDetEstimator(torch.autograd.Function):
    """Autograd Function that computes a logdet estimate and its gradients in the forward pass.

    The gradients of the summed estimate w.r.t. the input and the network
    parameters are computed immediately and stored; ``backward`` only rescales
    them by the incoming gradient.
    """

    @staticmethod
    def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
        """
        Args:
            estimator_fn: callable ``(g, x, n_power_series, vareps, coeff_fn, training) -> logdet``.
            gnet: network evaluated on ``x``.
            x: input tensor; detached and re-marked as requiring grad internally.
            n_power_series, vareps, coeff_fn: forwarded to ``estimator_fn``.
            training: when True, gradients are precomputed and saved for ``backward``.
            *g_params: parameters to differentiate with respect to.

        Returns:
            The detached logdet estimate.
        """
        ctx.training = training
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            g = gnet(x)
            ctx.g = g
            ctx.x = x
            logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)

            if training:
                # Differentiate the summed estimate right away; unused inputs yield None.
                grad_x, *grad_params = torch.autograd.grad(
                    logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
                )
                if grad_x is None:
                    grad_x = torch.zeros_like(x)
                # Saved layout: grad_x, then the parameters, then their gradients.
                ctx.save_for_backward(grad_x, *g_params, *grad_params)

        return safe_detach(logdetgrad)

    @staticmethod
    def backward(ctx, grad_logdetgrad):
        """Scale the precomputed gradients by the incoming gradient.

        Raises:
            ValueError: if ``forward`` was called with ``training=False``.
        """
        training = ctx.training
        if not training:
            raise ValueError('Provide training=True if using backward.')

        with torch.enable_grad():
            grad_x, *params_and_grad = ctx.saved_tensors
            g, x = ctx.g, ctx.x

            # Precomputed gradients.
            # First half of the saved list holds the parameters, second half their gradients.
            g_params = params_and_grad[:len(params_and_grad) // 2]
            grad_params = params_and_grad[len(params_and_grad) // 2:]

        # Update based on gradient from logdetgrad.
        # NOTE(review): only the first entry of the incoming gradient is used, which appears
        # to assume the same upstream gradient for every batch element - confirm.
        dL = grad_logdetgrad[0].detach()
        with torch.no_grad():
            grad_x.mul_(dL)
            grad_params = tuple([g.mul_(dL) if g is not None else None for g in grad_params])

        # One slot per forward() argument: only x and the parameters receive gradients.
        return (None, None, grad_x, None, None, None, None) + grad_params
416 |
417 |
def basic_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Estimate the logdet power series using repeated vector-Jacobian products.

    Term ``k`` is ``(-1)**(k + 1) / k * coeff_fn(k) * <vareps J^k, vareps>``,
    where ``J`` is the Jacobian of ``g`` w.r.t. ``x``.

    Returns:
        Tensor with one estimate per batch element (a 0-d zero if
        ``n_power_series`` is 0).
    """
    n_batch = x.shape[0]
    probe = vareps.view(n_batch, -1)
    running_vjp = vareps
    total = torch.tensor(0.).to(x)
    for order in range(1, n_power_series + 1):
        running_vjp = torch.autograd.grad(g, x, running_vjp, create_graph=training, retain_graph=True)[0]
        trace_est = torch.sum(running_vjp.view(n_batch, -1) * probe, 1)
        weight = (-1)**(order + 1) / order * coeff_fn(order)
        total = total + weight * trace_est
    return total
427 |
428 |
def neumann_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Estimate the logdet via a Neumann-series accumulation of vector-Jacobian products.

    The alternating, ``coeff_fn``-weighted sum of ``vareps J^k`` is built
    without tracking gradients; only the final product with the Jacobian is
    (optionally) differentiable.

    Returns:
        Tensor with one estimate per batch element.
    """
    running_vjp = vareps
    series_vjp = vareps
    with torch.no_grad():
        for order in range(1, n_power_series + 1):
            running_vjp = torch.autograd.grad(g, x, running_vjp, retain_graph=True)[0]
            series_vjp = series_vjp + (-1)**order * coeff_fn(order) * running_vjp
    # Single differentiable pass through the Jacobian.
    final_vjp = torch.autograd.grad(g, x, series_vjp, create_graph=training)[0]
    n_batch = x.shape[0]
    return torch.sum(final_vjp.view(n_batch, -1) * vareps.view(n_batch, -1), 1)
439 |
440 |
def mem_eff_wrapper(estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training):
    """Run ``estimator_fn`` through ``MemoryEfficientLogDetEstimator``.

    Raises:
        ValueError: if ``gnet`` is not an ``nn.Module``.
    """
    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if isinstance(gnet, nn.Module):
        net_params = list(gnet.parameters())
        return MemoryEfficientLogDetEstimator.apply(
            estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *net_params
        )
    raise ValueError('g is required to be an instance of nn.Module.')
451 |
452 |
453 | # -------- Helper distribution functions --------
454 | # These take python ints or floats, not PyTorch tensors.
455 |
456 |
def geometric_sample(p, n_samples):
    """Draw ``n_samples`` values from NumPy's geometric distribution with success probability ``p``."""
    draws = np.random.geometric(p, size=n_samples)
    return draws
459 |
460 |
def geometric_1mcdf(p, k, offset):
    """Return P(n >= k - offset) for a geometric ``n``; 1 whenever ``k <= offset``."""
    if k <= offset:
        return 1.
    shifted = k - offset
    return (1 - p)**max(shifted - 1, 0)
468 |
469 |
def poisson_sample(lamb, n_samples):
    """Draw ``n_samples`` values from NumPy's Poisson distribution with rate ``lamb``."""
    draws = np.random.poisson(lamb, size=n_samples)
    return draws
472 |
473 |
def poisson_1mcdf(lamb, k, offset):
    """Return P(n >= k - offset) for a Poisson ``n`` with rate ``lamb``; 1 whenever ``k <= offset``."""
    if k <= offset:
        return 1.
    shifted = k - offset
    # Accumulate sum_{i < shifted} lamb^i / i!, starting from the i = 0 term.
    partial = 1.
    for i in range(1, shifted):
        partial += lamb**i / math.factorial(i)
    return 1 - np.exp(-lamb) * partial
484 |
485 |
def sample_rademacher_like(y):
    """Return a tensor shaped and typed like ``y`` with entries drawn from {-1, +1}."""
    bits = torch.randint(low=0, high=2, size=y.shape)
    return bits.to(y) * 2 - 1
488 |
489 |
490 | # -------------- Helper functions --------------
491 |
492 |
def safe_detach(tensor):
    """Detach ``tensor`` from its graph while keeping its ``requires_grad`` flag."""
    detached = tensor.detach()
    detached.requires_grad_(tensor.requires_grad)
    return detached
495 |
496 |
def _flatten(sequence):
    """Concatenate all tensors in ``sequence`` into one 1-D tensor (empty tensor if none)."""
    pieces = [item.reshape(-1) for item in sequence]
    if len(pieces) > 0:
        return torch.cat(pieces)
    return torch.tensor([])
500 |
501 |
def _flatten_convert_none_to_zeros(sequence, like_sequence):
    """Flatten and concatenate ``sequence``, replacing each ``None`` entry with zeros.

    The zeros take the shape of the entry at the same position in
    ``like_sequence``. Returns an empty tensor when there is nothing to join.
    """
    pieces = []
    for item, template in zip(sequence, like_sequence):
        if item is not None:
            pieces.append(item.reshape(-1))
        else:
            pieces.append(torch.zeros_like(template).view(-1))
    if len(pieces) > 0:
        return torch.cat(pieces)
    return torch.tensor([])
505 |
--------------------------------------------------------------------------------
/lib/layers/iresblock.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | import logging
7 |
8 | logger = logging.getLogger()
9 |
10 | __all__ = ['iResBlock']
11 |
12 |
class iResBlock(nn.Module):
    """Invertible residual block ``y = x + nnet(x)`` with a power-series logdet estimate.

    ``inverse`` recovers ``x`` by fixed-point iteration of ``x <- y - nnet(x)``.
    """

    def __init__(
        self,
        nnet,
        geom_p=0.5,
        lamb=2.,
        n_power_series=None,
        exact_trace=False,
        brute_force=False,
        n_samples=1,
        n_exact_terms=2,
        n_dist='geometric',
        neumann_grad=True,
        grad_in_forward=False,
    ):
        """
        Args:
            nnet: a nn.Module
            geom_p: success probability of the geometric distribution; stored as its logit.
            lamb: rate of the Poisson distribution.
            n_power_series: number of power series. If not None, uses a biased approximation to logdet.
            exact_trace: if False, uses a Hutchinson trace estimator. Otherwise computes the exact full Jacobian.
            brute_force: Computes the exact logdet. Only available for 2D inputs.
            n_samples: number of draws of the random truncation length.
            n_exact_terms: number of leading series terms that get coefficient 1 in training mode.
            n_dist: 'geometric' or 'poisson'; distribution of the random truncation length.
            neumann_grad: in training mode, use ``neumann_logdet_estimator`` instead of
                ``basic_logdet_estimator``.
            grad_in_forward: in training mode, compute the estimator through ``mem_eff_wrapper``.
        """
        nn.Module.__init__(self)
        self.nnet = nnet
        self.n_dist = n_dist
        # geom_p is kept as a logit and mapped back with sigmoid in _logdetgrad.
        self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1. - geom_p)))
        self.lamb = nn.Parameter(torch.tensor(lamb))
        self.n_samples = n_samples
        self.n_power_series = n_power_series
        self.exact_trace = exact_trace
        self.brute_force = brute_force
        self.n_exact_terms = n_exact_terms
        self.grad_in_forward = grad_in_forward
        self.neumann_grad = neumann_grad

        # store the samples of n.
        self.register_buffer('last_n_samples', torch.zeros(self.n_samples))
        # first and second moments of the most recent training-time logdet estimate.
        self.register_buffer('last_firmom', torch.zeros(1))
        self.register_buffer('last_secmom', torch.zeros(1))

    def forward(self, x, logpx=None):
        """Return ``x + nnet(x)``; also return ``logpx - logdet`` if ``logpx`` is given."""
        if logpx is None:
            y = x + self.nnet(x)
            return y
        else:
            # _logdetgrad evaluates nnet itself, so its output g is reused here.
            g, logdetgrad = self._logdetgrad(x)
            return x + g, logpx - logdetgrad

    def inverse(self, y, logpy=None):
        """Recover ``x`` from ``y``; also return ``logpy + logdet`` if ``logpy`` is given."""
        x = self._inverse_fixed_point(y)
        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(x)[1]

    def _inverse_fixed_point(self, y, atol=1e-5, rtol=1e-5):
        """Iterate ``x <- y - nnet(x)`` until successive iterates agree.

        Stops when every element satisfies ``(x - x_prev)^2 < atol + rtol * |y|``,
        or once the iteration counter exceeds 1000 (which is logged).
        """
        x, x_prev = y - self.nnet(y), y
        i = 0
        tol = atol + y.abs() * rtol
        while not torch.all((x - x_prev)**2 / tol < 1):
            x, x_prev = y - self.nnet(x), x
            i += 1
            if i > 1000:
                logger.info('Iterations exceeded 1000 for inverse.')
                break
        return x

    def _logdetgrad(self, x):
        """Returns g(x) and logdet|d(x+g(x))/dx|."""

        with torch.enable_grad():
            # Exact path: taken for 2-D inputs with exactly 2 features when brute_force
            # is set or the module is not in training mode.
            if (self.brute_force or not self.training) and (x.ndimension() == 2 and x.shape[1] == 2):
                ###########################################
                # Brute-force compute Jacobian determinant.
                ###########################################
                x = x.requires_grad_(True)
                g = self.nnet(x)
                # Brute-force logdet only available for 2D.
                jac = batch_jacobian(g, x)
                # Closed-form 2x2 determinant of (I + jac).
                batch_dets = (jac[:, 0, 0] + 1) * (jac[:, 1, 1] + 1) - jac[:, 0, 1] * jac[:, 1, 0]
                return g, torch.log(torch.abs(batch_dets)).view(-1, 1)

            # Sampler and tail probability of the random truncation length.
            # NOTE(review): any other n_dist value leaves sample_fn/rcdf_fn undefined.
            if self.n_dist == 'geometric':
                geom_p = torch.sigmoid(self.geom_p).item()
                sample_fn = lambda m: geometric_sample(geom_p, m)
                rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset)
            elif self.n_dist == 'poisson':
                lamb = self.lamb.item()
                sample_fn = lambda m: poisson_sample(lamb, m)
                rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)

            if self.training:
                if self.n_power_series is None:
                    # Unbiased estimation.
                    lamb = self.lamb.item()
                    n_samples = sample_fn(self.n_samples)
                    n_power_series = max(n_samples) + self.n_exact_terms
                    # Term k is reweighted by the fraction of draws that reach it divided by
                    # the tail probability; this equals 1 for k <= n_exact_terms.
                    coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \
                        sum(n_samples >= k - self.n_exact_terms) / len(n_samples)
                else:
                    # Truncated estimation.
                    n_power_series = self.n_power_series
                    coeff_fn = lambda k: 1.
            else:
                # Unbiased estimation with more exact terms.
                # The number of exact terms is fixed at 20 outside training mode.
                lamb = self.lamb.item()
                n_samples = sample_fn(self.n_samples)
                n_power_series = max(n_samples) + 20
                coeff_fn = lambda k: 1 / rcdf_fn(k, 20) * \
                    sum(n_samples >= k - 20) / len(n_samples)

            if not self.exact_trace:
                ####################################
                # Power series with trace estimator.
                ####################################
                # Gaussian probe vector for the trace estimate.
                vareps = torch.randn_like(x)

                # Choose the type of estimator.
                if self.training and self.neumann_grad:
                    estimator_fn = neumann_logdet_estimator
                else:
                    estimator_fn = basic_logdet_estimator

                # Do backprop-in-forward to save memory.
                if self.training and self.grad_in_forward:
                    g, logdetgrad = mem_eff_wrapper(
                        estimator_fn, self.nnet, x, n_power_series, vareps, coeff_fn, self.training
                    )
                else:
                    x = x.requires_grad_(True)
                    g = self.nnet(x)
                    logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, self.training)
            else:
                ############################################
                # Power series with exact trace computation.
                ############################################
                x = x.requires_grad_(True)
                g = self.nnet(x)
                jac = batch_jacobian(g, x)
                # k = 1 term of the series; higher powers are accumulated in the loop.
                logdetgrad = batch_trace(jac)
                jac_k = jac
                for k in range(2, n_power_series + 1):
                    jac_k = torch.bmm(jac, jac_k)
                    logdetgrad = logdetgrad + (-1)**(k+1) / k * coeff_fn(k) * batch_trace(jac_k)

            # Record the drawn truncation lengths and estimator moments in the buffers.
            if self.training and self.n_power_series is None:
                self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples))
                estimator = logdetgrad.detach()
                self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom))
                self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom))
            return g, logdetgrad.view(-1, 1)

    def extra_repr(self):
        return 'dist={}, n_samples={}, n_power_series={}, neumann_grad={}, exact_trace={}, brute_force={}'.format(
            self.n_dist, self.n_samples, self.n_power_series, self.neumann_grad, self.exact_trace, self.brute_force
        )
170 |
171 |
def batch_jacobian(g, x):
    """Build the per-sample Jacobian of ``g`` w.r.t. ``x``, one output column at a time.

    Args:
        g: tensor of shape (N, D_out) computed from ``x``.
        x: tensor of shape (N, D_in) that requires grad.

    Returns:
        Tensor of shape (N, D_out, D_in); the graph is kept so the result
        can be differentiated again.
    """
    rows = [
        torch.autograd.grad(torch.sum(g[:, col]), x, create_graph=True)[0].view(x.shape[0], 1, x.shape[1])
        for col in range(g.shape[1])
    ]
    return torch.cat(rows, 1)
177 |
178 |
def batch_trace(M):
    """Return the trace of each matrix in a batch of shape (N, D, D) as a length-N tensor."""
    n_batch = M.shape[0]
    step = M.shape[1] + 1  # distance between diagonal entries in the row-major flattening
    flattened = M.view(n_batch, -1)
    return flattened[:, ::step].sum(1)
181 |
182 |
183 | #####################
184 | # Logdet Estimators
185 | #####################
class MemoryEfficientLogDetEstimator(torch.autograd.Function):
    """Autograd Function returning ``gnet(x)`` and a logdet estimate, with gradients precomputed.

    The gradients of the summed logdet estimate w.r.t. the input and the
    network parameters are computed in ``forward`` and stored; ``backward``
    rescales them and adds the gradient flowing through ``g``.
    """

    @staticmethod
    def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
        """
        Args:
            estimator_fn: callable ``(g, x, n_power_series, vareps, coeff_fn, training) -> logdet``.
            gnet: network evaluated on ``x``.
            x: input tensor; detached and re-marked as requiring grad internally.
            n_power_series, vareps, coeff_fn: forwarded to ``estimator_fn``.
            training: when True, gradients are precomputed and saved for ``backward``.
            *g_params: parameters to differentiate with respect to.

        Returns:
            Tuple ``(g, logdetgrad)``, both detached.
        """
        ctx.training = training
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            g = gnet(x)
            # Kept on ctx (not via save_for_backward) so backward can differentiate g again.
            ctx.g = g
            ctx.x = x
            logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)

            if training:
                # Differentiate the summed estimate right away; unused inputs yield None.
                grad_x, *grad_params = torch.autograd.grad(
                    logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
                )
                if grad_x is None:
                    grad_x = torch.zeros_like(x)
                # Saved layout: grad_x, then the parameters, then their gradients.
                ctx.save_for_backward(grad_x, *g_params, *grad_params)

        return safe_detach(g), safe_detach(logdetgrad)

    @staticmethod
    def backward(ctx, grad_g, grad_logdetgrad):
        """Combine the precomputed logdet gradients with the gradient through ``g``.

        Raises:
            ValueError: if ``forward`` was called with ``training=False``.
        """
        training = ctx.training
        if not training:
            raise ValueError('Provide training=True if using backward.')

        with torch.enable_grad():
            grad_x, *params_and_grad = ctx.saved_tensors
            g, x = ctx.g, ctx.x

            # Precomputed gradients.
            # First half of the saved list holds the parameters, second half their gradients.
            g_params = params_and_grad[:len(params_and_grad) // 2]
            grad_params = params_and_grad[len(params_and_grad) // 2:]

            # Gradient contribution of the first output, g, given its upstream gradient.
            dg_x, *dg_params = torch.autograd.grad(g, [x] + g_params, grad_g, allow_unused=True)

        # Update based on gradient from logdetgrad.
        # NOTE(review): only the first entry of the incoming gradient is used, which appears
        # to assume the same upstream gradient for every batch element - confirm.
        dL = grad_logdetgrad[0].detach()
        with torch.no_grad():
            grad_x.mul_(dL)
            grad_params = tuple([g.mul_(dL) if g is not None else None for g in grad_params])

        # Update based on gradient from g.
        with torch.no_grad():
            grad_x.add_(dg_x)
            grad_params = tuple([dg.add_(djac) if djac is not None else dg for dg, djac in zip(dg_params, grad_params)])

        # One slot per forward() argument: only x and the parameters receive gradients.
        return (None, None, grad_x, None, None, None, None) + grad_params
236 |
237 |
def basic_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Estimate the logdet power series using repeated vector-Jacobian products.

    Term ``k`` is ``(-1)**(k + 1) / k * coeff_fn(k) * <vareps J^k, vareps>``,
    where ``J`` is the Jacobian of ``g`` w.r.t. ``x``.

    Returns:
        Tensor with one estimate per batch element (a 0-d zero if
        ``n_power_series`` is 0).
    """
    batch = x.shape[0]
    flat_probe = vareps.view(batch, -1)
    vec = vareps
    estimate = torch.tensor(0.).to(x)
    k = 1
    while k <= n_power_series:
        vec = torch.autograd.grad(g, x, vec, create_graph=training, retain_graph=True)[0]
        trace_k = torch.sum(vec.view(batch, -1) * flat_probe, 1)
        estimate = estimate + (-1)**(k + 1) / k * coeff_fn(k) * trace_k
        k += 1
    return estimate
247 |
248 |
def neumann_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Estimate the logdet via a Neumann-series accumulation of vector-Jacobian products.

    The alternating, ``coeff_fn``-weighted sum of ``vareps J^k`` is built
    without tracking gradients; only the final product with the Jacobian is
    (optionally) differentiable.

    Returns:
        Tensor with one estimate per batch element.
    """
    vec = vareps
    accumulated = vareps
    with torch.no_grad():
        for k in range(1, n_power_series + 1):
            vec = torch.autograd.grad(g, x, vec, retain_graph=True)[0]
            accumulated = accumulated + (-1)**k * coeff_fn(k) * vec
    # Single differentiable pass through the Jacobian.
    through_jac = torch.autograd.grad(g, x, accumulated, create_graph=training)[0]
    batch = x.shape[0]
    return torch.sum(through_jac.view(batch, -1) * vareps.view(batch, -1), 1)
259 |
260 |
def mem_eff_wrapper(estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training):
    """Run ``estimator_fn`` through ``MemoryEfficientLogDetEstimator``.

    The network must be an ``nn.Module`` because its parameters are passed
    explicitly to the autograd function; there is no other way to reach the
    variables along the execution path.

    Raises:
        ValueError: if ``gnet`` is not an ``nn.Module``.
    """
    if isinstance(gnet, nn.Module):
        params = list(gnet.parameters())
        return MemoryEfficientLogDetEstimator.apply(
            estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *params
        )
    raise ValueError('g is required to be an instance of nn.Module.')
271 |
272 |
273 | # -------- Helper distribution functions --------
274 | # These take python ints or floats, not PyTorch tensors.
275 |
276 |
def geometric_sample(p, n_samples):
    """Draw ``n_samples`` values from numpy's geometric distribution with parameter ``p``."""
    draws = np.random.geometric(p, size=n_samples)
    return draws
279 |
280 |
def geometric_1mcdf(p, k, offset):
    """P(n >= k) for a geometric variable shifted by ``offset``.

    Returns 1. whenever ``k <= offset``; otherwise ``(1 - p)`` raised to
    ``max(k - offset - 1, 0)``.
    """
    if k <= offset:
        return 1.
    shifted = k - offset
    exponent = max(shifted - 1, 0)
    return (1 - p)**exponent
288 |
289 |
def poisson_sample(lamb, n_samples):
    """Draw ``n_samples`` values from numpy's Poisson distribution with rate ``lamb``."""
    draws = np.random.poisson(lamb, size=n_samples)
    return draws
292 |
293 |
def poisson_1mcdf(lamb, k, offset):
    """P(n >= k) for a Poisson variable shifted by ``offset``.

    Returns 1. whenever ``k <= offset``; otherwise one minus
    ``exp(-lamb)`` times the partial sum of ``lamb**i / i!`` for
    ``i`` in ``0 .. k - offset - 1``.
    """
    if k <= offset:
        return 1.
    shifted = k - offset
    # The i = 0 term of the series is the initial 1.
    partial = 1.
    for i in range(1, shifted):
        partial += lamb**i / math.factorial(i)
    return 1 - np.exp(-lamb) * partial
304 |
305 |
def sample_rademacher_like(y):
    """Return a tensor of random -1/+1 entries matching ``y`` in shape, dtype and device."""
    bits = torch.randint(low=0, high=2, size=y.shape)
    # Map {0, 1} to {-1, +1} after converting to y's dtype/device.
    return bits.to(y) * 2 - 1
308 |
309 |
310 | # -------------- Helper functions --------------
311 |
312 |
def safe_detach(tensor):
    """Detach ``tensor`` from its graph while keeping its ``requires_grad`` flag."""
    wants_grad = tensor.requires_grad
    detached = tensor.detach()
    return detached.requires_grad_(wants_grad)
315 |
316 |
317 | def _flatten(sequence):
318 | flat = [p.reshape(-1) for p in sequence]
319 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
320 |
321 |
322 | def _flatten_convert_none_to_zeros(sequence, like_sequence):
323 | flat = [p.reshape(-1) if p is not None else torch.zeros_like(q).view(-1) for p, q in zip(sequence, like_sequence)]
324 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
325 |
--------------------------------------------------------------------------------
/lib/layers/mask_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def _get_checkerboard_mask(x, swap=False):
5 | n, c, h, w = x.size()
6 |
7 | H = ((h - 1) // 2 + 1) * 2 # H = h + 1 if h is odd and h if h is even
8 | W = ((w - 1) // 2 + 1) * 2
9 |
10 | # construct checkerboard mask
11 | if not swap:
12 | mask = torch.Tensor([[1, 0], [0, 1]]).repeat(H // 2, W // 2)
13 | else:
14 | mask = torch.Tensor([[0, 1], [1, 0]]).repeat(H // 2, W // 2)
15 | mask = mask[:h, :w]
16 | mask = mask.contiguous().view(1, 1, h, w).expand(n, c, h, w).type_as(x.data)
17 |
18 | return mask
19 |
20 |
21 | def _get_channel_mask(x, swap=False):
22 | n, c, h, w = x.size()
23 | assert (c % 2 == 0)
24 |
25 | # construct channel-wise mask
26 | mask = torch.zeros(x.size())
27 | if not swap:
28 | mask[:, :c // 2] = 1
29 | else:
30 | mask[:, c // 2:] = 1
31 | return mask
32 |
33 |
def get_mask(x, mask_type=None):
    """Return the mask named by ``mask_type`` for the tensor ``x``.

    ``None`` yields an all-zero mask like ``x``. Supported names are
    'channel0', 'channel1', 'checkerboard0' and 'checkerboard1', where the
    trailing digit 1 selects the swapped variant.

    Raises:
        ValueError: for any other ``mask_type``.
    """
    if mask_type is None:
        return torch.zeros(x.size()).to(x)
    if mask_type in ('channel0', 'channel1'):
        return _get_channel_mask(x, swap=(mask_type == 'channel1'))
    if mask_type in ('checkerboard0', 'checkerboard1'):
        return _get_checkerboard_mask(x, swap=(mask_type == 'checkerboard1'))
    raise ValueError('Unknown mask type {}'.format(mask_type))
47 |
--------------------------------------------------------------------------------
/lib/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Parameter
4 |
5 | __all__ = ['MovingBatchNorm1d', 'MovingBatchNorm2d']
6 |
7 |
class MovingBatchNormNd(nn.Module):
    """Mean-only normalization layer with a running mean and optional bias.

    Subclasses supply ``shape``, the broadcast shape used to view per-channel
    vectors against the input. ``forward`` subtracts a mean and adds the bias;
    ``inverse`` undoes it using the running mean. Both pass an optional
    log-density argument through unchanged.
    """

    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        # eps is stored (and shown by __repr__) but not used by this class.
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        """Broadcast shape for per-channel vectors; defined by subclasses."""
        raise NotImplementedError

    def reset_parameters(self):
        """Zero the running mean and, if present, the bias."""
        self.running_mean.zero_()
        if self.affine:
            self.bias.data.zero_()

    def forward(self, x, logpx=None):
        """Subtract the (running or lagged) mean from ``x`` and add the bias.

        In training mode the running mean and step counter are updated from
        the batch mean; when ``bn_lag > 0`` the mean that is subtracted is a
        bias-corrected blend of the batch mean and the previous running mean.
        Returns ``y`` alone, or ``(y, logpx)`` when ``logpx`` is given.
        """
        n_channels = x.size(1)
        used_mean = self.running_mean.clone().detach()

        if self.training:
            # Per-channel mean over every other dimension.
            per_channel = x.transpose(0, 1).contiguous().view(n_channels, -1)
            batch_mean = torch.mean(per_channel, dim=1)

            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))

            # Exponential moving average of the batch means.
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.step += 1

        y = x - used_mean.view(*self.shape).expand_as(x)
        if self.affine:
            y = y + self.bias.view(*self.shape).expand_as(x)

        return y if logpx is None else (y, logpx)

    def inverse(self, y, logpy=None):
        """Undo ``forward`` using the running mean: remove the bias, add the mean."""
        if self.affine:
            y = y - self.bias.view(*self.shape).expand_as(y)

        x = y + self.running_mean.view(*self.shape).expand_as(y)

        return x if logpy is None else (x, logpy)

    def __repr__(self):
        template = (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'
        )
        return template.format(name=self.__class__.__name__, **self.__dict__)
86 |
87 |
class MovingBatchNorm1d(MovingBatchNormNd):
    """MovingBatchNormNd variant whose per-channel vectors are viewed as (1, C)."""

    @property
    def shape(self):
        # View shape used to broadcast per-channel vectors: [1, -1].
        return [1, -1]
93 |
94 |
class MovingBatchNorm2d(MovingBatchNormNd):
    """MovingBatchNormNd variant whose per-channel vectors are viewed as (1, C, 1, 1)."""

    @property
    def shape(self):
        # View shape used to broadcast per-channel vectors: [1, -1, 1, 1].
        return [1, -1, 1, 1]
100 |
--------------------------------------------------------------------------------
/lib/layers/squeeze.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __all__ = ['SqueezeLayer']
5 |
6 |
class SqueezeLayer(nn.Module):
    """Layer wrapping ``squeeze`` (forward) and ``unsqueeze`` (inverse).

    An optional log-density argument is passed through unchanged.
    """

    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, x, logpx=None, restore=False):
        """Squeeze ``x``; return it alone or paired with ``logpx`` when given.

        ``restore`` is accepted but not used.
        """
        squeezed = squeeze(x, self.downscale_factor)
        return squeezed if logpx is None else (squeezed, logpx)

    def inverse(self, y, logpy=None):
        """Unsqueeze ``y``; return it alone or paired with ``logpy`` when given."""
        unsqueezed = unsqueeze(y, self.downscale_factor)
        return unsqueezed if logpy is None else (unsqueezed, logpy)
26 |
27 |
def unsqueeze(input, upscale_factor=2):
    '''
    [:, C*r^2, H, W] -> [:, C, H*r, W*r]
    '''
    expanded = torch.pixel_shuffle(input, upscale_factor)
    return expanded
30 |
31 |
def squeeze(input, downscale_factor=2):
    '''
    [:, C, H*r, W*r] -> [:, C*r^2, H, W]
    '''
    r = downscale_factor
    batch, channels, full_h, full_w = input.shape
    h, w = full_h // r, full_w // r

    # Split each spatial axis into (coarse, within-block) and move the two
    # within-block axes next to the channel axis before merging them into it.
    blocks = input.reshape(batch, channels, h, r, w, r)
    return blocks.permute(0, 1, 3, 5, 2, 4).reshape(batch, channels * (r**2), h, w)
46 |
--------------------------------------------------------------------------------
/lib/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 |
class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:
    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{i}}\pi))
    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0`(after restart), set :math:`\eta_t=\eta_{max}`.
    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mul >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        # T_cur must exist before the base constructor runs: depending on the
        # PyTorch version the base class calls self.step() (with or without an
        # epoch argument) during initialisation.
        self.T_cur = last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the annealed learning rate for every parameter group."""
        return [
            self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
            for base_lr in self.base_lrs
        ]

    def step(self, epoch=None):
        """Step could be called after every update, i.e. if one epoch has 10 iterations
        (number_of_train_examples / batch_size), we should call SGDR.step(0.1), SGDR.step(0.2), etc.
        This function can be called in an interleaved way.
        Example:
            >>> scheduler = SGDR(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)

        Raises:
            ValueError: if an explicit ``epoch`` is negative.
        """
        # The very first (implicit) step of a fresh scheduler starts at epoch 0.
        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                # Warm restart: wrap the counter and lengthen the next cycle.
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    # n is the index of the cycle that contains `epoch`.
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult**(n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
76 |
--------------------------------------------------------------------------------
/lib/optimizers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
class Adam(Optimizer):
    r"""Implements Adam algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Weight decay is applied directly to the parameters after the Adam update,
    scaled by the bias-corrected step size.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        # Older pickled states may not carry the amsgrad flag.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                if group['weight_decay'] != 0:
                    # In-place update; the out-of-place `add` would discard
                    # the result and leave the parameter untouched.
                    p.data.add_(p.data, alpha=-step_size * group['weight_decay'])

        return loss
108 |
109 |
class Adamax(Optimizer):
    """Implements Adamax algorithm (a variant of Adam based on infinity norm).

    It has been proposed in `Adam: A Method for Stochastic Optimization`__.

    Weight decay is applied directly to the parameters after the Adamax
    update, scaled by the bias-corrected learning rate.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    __ https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(Adamax, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adamax does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_inf'] = torch.zeros_like(p.data)

                exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
                beta1, beta2 = group['betas']
                eps = group['eps']

                state['step'] += 1

                # Update biased first moment estimate.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # Update the exponentially weighted infinity norm.
                norm_buf = torch.cat([exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0)
                # Element-wise max of the two stacked candidates, written back
                # into exp_inf; the index output is a throwaway tensor.
                torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))

                bias_correction = 1 - beta1**state['step']
                clr = group['lr'] / bias_correction

                p.data.addcdiv_(exp_avg, exp_inf, value=-clr)

                if group['weight_decay'] != 0:
                    # In-place update; the out-of-place `add` would discard
                    # the result and leave the parameter untouched.
                    p.data.add_(p.data, alpha=-clr * group['weight_decay'])

        return loss
190 |
191 |
class RMSprop(Optimizer):
    """Implements RMSprop algorithm.

    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.

    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    Weight decay is applied directly to the parameters after the RMSprop
    update, scaled by the learning rate.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing constant (default: 0.99)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        centered (bool, optional) : if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    """

    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
        super(RMSprop, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RMSprop, self).__setstate__(state)
        # Older pickled states may not carry these options.
        for group in self.param_groups:
            group.setdefault('momentum', 0)
            group.setdefault('centered', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.zeros_like(p.data)
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(p.data)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p.data)

                square_avg = state['square_avg']
                alpha = group['alpha']

                state['step'] += 1

                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                    # Out-of-place addcmul on purpose: square_avg itself must
                    # keep the uncentered second moment.
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt().add_(group['eps'])
                else:
                    avg = square_avg.sqrt().add_(group['eps'])

                if group['momentum'] > 0:
                    buf = state['momentum_buffer']
                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
                    p.data.add_(buf, alpha=-group['lr'])
                else:
                    p.data.addcdiv_(grad, avg, value=-group['lr'])

                if group['weight_decay'] != 0:
                    # In-place update; the out-of-place `add` would discard
                    # the result and leave the parameter untouched.
                    p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

        return loss
290 |
--------------------------------------------------------------------------------
/lib/tabular.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from collections import Counter
4 |
5 | import numpy as np
6 |
7 | import torch
8 |
9 | import pandas
10 |
11 | import h5py
12 |
13 |
class SupervisedDataset(torch.utils.data.Dataset):
    """In-memory dataset of (x, y) pairs tagged with a name and a role.

    When ``y`` is omitted, an all-zero long tensor with one entry per row of
    ``x`` is used. ``role`` must be one of "train", "valid" or "test".
    """

    def __init__(self, name, role, x, y=None):
        if y is None:
            # Placeholder labels, one per row of x.
            y = torch.zeros(x.shape[0]).long()

        assert x.shape[0] == y.shape[0]
        assert role in ["train", "valid", "test"]

        self.name, self.role = name, role
        self.x, self.y = x, y

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, index):
        sample = self.x[index]
        label = self.y[index]
        return sample, label

    def to(self, device):
        """Return a new dataset whose tensors have been moved to ``device``."""
        moved_x = self.x.to(device)
        moved_y = self.y.to(device)
        return SupervisedDataset(self.name, self.role, moved_x, moved_y)
41 |
42 |
def normalize_raw_data(data, mu, s):
    """Shift ``data`` by ``mu`` and divide the result by ``s``."""
    centered = data - mu
    return centered / s
45 |
46 |
def make_tabular_train_valid_split(data, frac):
    """Split ``data`` row-wise into a leading train part and a trailing valid part.

    Args:
        data: row-sliceable array with a ``shape`` attribute.
        frac: fraction of rows (rounded down) that go to the validation part.

    Returns:
        ``(train_data, valid_data)``. When the validation size rounds down to
        zero, every row stays in the train part and the valid part is empty.
    """
    n_total = data.shape[0]
    n_valid = int(frac * n_total)
    # Slice from an explicit boundary: `data[-n_valid:]` with n_valid == 0
    # would be `data[-0:]`, i.e. the whole array instead of nothing.
    boundary = max(n_total - n_valid, 0)
    train_data = data[:boundary]
    valid_data = data[boundary:]
    return train_data, valid_data
52 |
53 |
def make_tabular_train_valid_test_split(data, frac):
    """Split ``data`` row-wise into train, valid and test parts.

    The trailing ``int(frac * n_rows)`` rows become the test part; the rest is
    handed to ``make_tabular_train_valid_split`` with the same ``frac``.

    Returns:
        ``(train_data, valid_data, test_data)``. When the test size rounds
        down to zero the test part is empty and no rows are removed.
    """
    n_total = data.shape[0]
    n_test = int(frac * n_total)
    # Slice from an explicit boundary: `data[-n_test:]` with n_test == 0
    # would be `data[-0:]`, i.e. the whole array instead of nothing.
    boundary = max(n_total - n_test, 0)
    test_data = data[boundary:]
    data = data[:boundary]

    train_data, valid_data = make_tabular_train_valid_split(data, frac)
    return train_data, valid_data, test_data
61 |
62 |
def get_miniboone_raw(data_root):
    """Load miniboone/data.npy, split it 80/10/10-style and standardize it.

    Mean and standard deviation are computed over the train and valid rows
    together and then applied to all three parts.
    """
    raw = np.load(os.path.join(data_root, "miniboone/data.npy"))

    train_raw, valid_raw, test_raw = make_tabular_train_valid_test_split(raw, 0.1)

    # Statistics come from train + valid only; test rows are excluded.
    fit_rows = np.vstack((train_raw, valid_raw))
    mu = fit_rows.mean(axis=0)
    s = fit_rows.std(axis=0)

    return (
        normalize_raw_data(train_raw, mu, s),
        normalize_raw_data(valid_raw, mu, s),
        normalize_raw_data(test_raw, mu, s),
    )
77 |
78 |
def get_gas_raw(data_root):
    """Load gas/ethylene_CO.pickle, prune correlated columns, standardize and split.

    Returns the (train, valid, test) arrays produced by
    ``make_tabular_train_valid_test_split`` with a fraction of 0.1.
    """

    def get_gas_correlation_numbers(data):
        # For every column, count how many columns (itself included) have a
        # pairwise correlation above 0.98 with it.
        C = data.corr()
        A = C > 0.98
        B = A.to_numpy().sum(axis=1)
        return B

    data = pandas.read_pickle(os.path.join(data_root, "gas/ethylene_CO.pickle"))
    # Drop the "Meth", "Eth" and "Time" columns before any statistics.
    data.drop("Meth", axis=1, inplace=True)
    data.drop("Eth", axis=1, inplace=True)
    data.drop("Time", axis=1, inplace=True)

    # Repeatedly remove the first column that is highly correlated with some
    # other column (count > 1) until no such column is left.
    B = get_gas_correlation_numbers(data)
    while np.any(B > 1):
        col_to_remove = np.where(B > 1)[0][0]
        col_name = data.columns[col_to_remove]
        data.drop(col_name, axis=1, inplace=True)
        B = get_gas_correlation_numbers(data)

    # Statistics are taken over the full table, before splitting.
    data = normalize_raw_data(data, data.mean(), data.std()).to_numpy()
    return make_tabular_train_valid_test_split(data, 0.1)
101 |
102 |
def get_hepmass_raw(data_root):
    """Load the hepmass train/test CSVs, filter, standardize and prune features.

    Returns (train, valid, test) arrays; valid is split off the end of the
    training rows with a fraction of 0.1.
    """
    train_data_path = os.path.join(data_root, "hepmass/1000_train.csv")
    test_data_path = os.path.join(data_root, "hepmass/1000_test.csv")

    train_raw = pandas.read_csv(filepath_or_buffer=train_data_path, index_col=False)
    test_raw = pandas.read_csv(filepath_or_buffer=test_data_path, index_col=False)

    # Keep only rows whose first column equals 1, then drop that column.
    train_raw = train_raw[train_raw[train_raw.columns[0]] == 1]
    train_raw = train_raw.drop(train_raw.columns[0], axis=1)

    test_raw = test_raw[test_raw[test_raw.columns[0]] == 1]
    test_raw = test_raw.drop(test_raw.columns[0], axis=1)
    # The test table additionally loses its last column.
    # NOTE(review): presumably a spurious trailing column in the test CSV -- confirm.
    test_raw = test_raw.drop(test_raw.columns[-1], axis=1)

    # Both parts are standardized with the training statistics.
    mu = train_raw.mean()
    s = train_raw.std()
    train_raw = normalize_raw_data(train_raw, mu, s).to_numpy()
    test_raw = normalize_raw_data(test_raw, mu, s).to_numpy()

    i = 0
    features_to_remove = []
    for feature in train_raw.T:
        c = Counter(feature)
        # Despite its name, this is the count of the smallest value: the
        # items are sorted by value and the first count is taken.
        max_count = np.array([v for k, v in sorted(c.items())])[0]
        if max_count > 5:
            features_to_remove.append(i)
        i += 1
    # Remove the flagged feature columns from both parts.
    train_raw = train_raw[:, np.array([i for i in range(train_raw.shape[1]) if i not in features_to_remove])]
    test_raw = test_raw[:, np.array([i for i in range(test_raw.shape[1]) if i not in features_to_remove])]

    train_raw, valid_raw = make_tabular_train_valid_split(train_raw, 0.1)
    return train_raw, valid_raw, test_raw
135 |
136 |
def get_power_raw(data_root):
    """Load power/data.npy, shuffle, drop two columns, add noise, split and standardize.

    Uses the global numpy random state (shuffle and noise), so results depend
    on the seed and on the order of the random calls below.
    """
    data = np.load(os.path.join(data_root, "power/data.npy"))
    # In-place row shuffle.
    np.random.shuffle(data)
    n = data.shape[0]

    # Column 3 is removed first so that index 1 still refers to the original
    # column 1 in the second call.
    data = np.delete(data, 3, axis=1)
    data = np.delete(data, 1, axis=1)

    # Uniform noise per column group; the last column receives none.
    gap_noise = 0.001*np.random.rand(n, 1)
    voltage_noise = 0.01*np.random.rand(n, 1)
    sm_noise = np.random.rand(n, 3)
    time_noise = np.zeros((n, 1))

    noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise))
    data = data + noise

    train_raw, valid_raw, test_raw = make_tabular_train_valid_test_split(data, 0.1)

    # Statistics come from train + valid only; test rows are excluded.
    train_and_valid = np.vstack((train_raw, valid_raw))
    mu = train_and_valid.mean(axis=0)
    s = train_and_valid.std(axis=0)

    train_raw = normalize_raw_data(train_raw, mu, s)
    valid_raw = normalize_raw_data(valid_raw, mu, s)
    test_raw = normalize_raw_data(test_raw, mu, s)

    return train_raw, valid_raw, test_raw
164 |
165 |
def get_bsds300_raw(data_root):
    """Read the "train", "validation" and "test" datasets from BSDS300/BSDS300.hdf5."""
    path = os.path.join(data_root, "BSDS300", "BSDS300.hdf5")
    with h5py.File(path, "r") as f:
        # `[()]` reads each dataset fully into memory before the file closes.
        train_raw, valid_raw, test_raw = (f[key][()] for key in ("train", "validation", "test"))
    return train_raw, valid_raw, test_raw
172 |
173 |
def get_raw_tabular_datasets(name, data_root):
    """Dispatch to the loader for the tabular dataset called ``name``.

    Supported names: "miniboone", "gas", "hepmass", "power", "bsds300".

    Raises:
        NotImplementedError: for any other name.
    """
    if name == "miniboone":
        return get_miniboone_raw(data_root)
    if name == "gas":
        return get_gas_raw(data_root)
    if name == "hepmass":
        return get_hepmass_raw(data_root)
    if name == "power":
        return get_power_raw(data_root)
    if name == "bsds300":
        return get_bsds300_raw(data_root)
    raise NotImplementedError
189 |
190 |
def get_tabular_datasets(name, data_root):
    """Load a tabular dataset and wrap its three parts as ``SupervisedDataset`` objects.

    Prints the shape of the raw training array and returns
    ``(train_dset, valid_dset, test_dset)`` holding tensors of the default dtype.
    """
    train_raw, valid_raw, test_raw = get_raw_tabular_datasets(name, data_root)
    print(train_raw.shape)

    def _wrap(role, raw):
        # One dataset per role, converted to the current default dtype.
        return SupervisedDataset(name, role, torch.tensor(raw, dtype=torch.get_default_dtype()))

    train_dset = _wrap("train", train_raw)
    valid_dset = _wrap("valid", valid_raw)
    test_dset = _wrap("test", test_raw)
    return train_dset, valid_dset, test_dset
--------------------------------------------------------------------------------
/lib/toy_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn
3 | import sklearn.datasets
4 | from sklearn.utils import shuffle as util_shuffle
5 |
6 |
7 | # Dataset iterator
def inf_train_gen(data, batch_size=200):
    """Sample one batch of 2-D points from the toy distribution named ``data``.

    Known names: "swissroll", "circles", "rings", "moons", "8gaussians",
    "pinwheel", "2spirals", "checkerboard", "line" and "cos". Any other name
    falls back to "8gaussians". Sampling uses numpy's global random state
    (and scikit-learn's generators for some datasets).
    """
    if data == "swissroll":
        pts = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0)[0]
        pts = pts.astype("float32")[:, [0, 2]]
        pts /= 5
        return pts

    if data == "circles":
        pts = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08)[0]
        pts = pts.astype("float32")
        pts *= 3
        return pts

    if data == "rings":
        quarter = batch_size // 4
        # Outer three rings get a quarter each; the innermost takes the rest.
        counts = [quarter, quarter, quarter, batch_size - 3 * quarter]
        radii = [1.0, 0.75, 0.5, 0.25]

        xs, ys = [], []
        for count, radius in zip(counts, radii):
            # so as not to have the first point = last point, we set endpoint=False
            theta = np.linspace(0, 2 * np.pi, count, endpoint=False)
            xs.append(np.cos(theta) * radius)
            ys.append(np.sin(theta) * radius)

        X = np.vstack([np.hstack(xs), np.hstack(ys)]).T * 3.0
        X = util_shuffle(X)

        # Add noise
        X = X + np.random.normal(scale=0.08, size=X.shape)

        return X.astype("float32")

    if data == "moons":
        pts = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0]
        pts = pts.astype("float32")
        pts = pts * 2 + np.array([-1, -0.2])
        return pts

    if data == "8gaussians":
        scale = 4.
        diag = 1. / np.sqrt(2)
        unit_centers = [(1, 0), (-1, 0), (0, 1), (0, -1),
                        (diag, diag), (diag, -diag), (-diag, diag), (-diag, -diag)]
        centers = [(scale * cx, scale * cy) for cx, cy in unit_centers]

        points = []
        for _ in range(batch_size):
            # Draw the offset first and the component index second; the order
            # of the random calls determines the sample.
            point = np.random.randn(2) * 0.5
            cx, cy = centers[np.random.randint(8)]
            point[0] += cx
            point[1] += cy
            points.append(point)
        points = np.array(points, dtype="float32")
        points /= 1.414
        return points

    if data == "pinwheel":
        radial_std = 0.3
        tangential_std = 0.1
        num_classes = 5
        num_per_class = batch_size // num_classes
        rate = 0.25
        rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

        spread = np.array([radial_std, tangential_std])
        features = np.random.randn(num_classes * num_per_class, 2) * spread
        features[:, 0] += 1.
        labels = np.repeat(np.arange(num_classes), num_per_class)

        angles = rads[labels] + rate * np.exp(features[:, 0])
        cos_a, sin_a = np.cos(angles), np.sin(angles)
        rotations = np.stack([cos_a, -sin_a, sin_a, cos_a])
        rotations = np.reshape(rotations.T, (-1, 2, 2))

        return 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations))

    if data == "2spirals":
        half = batch_size // 2
        n = np.sqrt(np.random.rand(half, 1)) * 540 * (2 * np.pi) / 360
        d1x = -np.cos(n) * n + np.random.rand(half, 1) * 0.5
        d1y = np.sin(n) * n + np.random.rand(half, 1) * 0.5
        # Second arm is the point reflection of the first.
        x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3
        x += np.random.randn(*x.shape) * 0.1
        return x

    if data == "checkerboard":
        x1 = np.random.rand(batch_size) * 4 - 2
        x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2
        x2 = x2_ + (np.floor(x1) % 2)
        return np.concatenate([x1[:, None], x2[:, None]], 1) * 2

    if data == "line":
        x = np.random.rand(batch_size) * 5 - 2.5
        return np.stack((x, x), 1)

    if data == "cos":
        x = np.random.rand(batch_size) * 5 - 2.5
        y = np.sin(x) * 2.5
        return np.stack((x, y), 1)

    # Unknown names fall back to the 8-Gaussians mixture.
    return inf_train_gen("8gaussians", batch_size)
120 |
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from numbers import Number
4 | import logging
5 | import torch
6 |
7 |
def makedirs(dirname):
    """Create ``dirname`` (and any missing parents) unless it already exists.

    Args:
        dirname: path of the directory to create.
    """
    if not os.path.exists(dirname):
        # exist_ok=True closes the window between the existence check and the
        # creation: another process creating the same directory in between
        # no longer makes this call raise FileExistsError.
        os.makedirs(dirname, exist_ok=True)
11 |
12 |
def get_logger(logpath, filepath, package_files=None, displaying=True, saving=True, debug=False):
    """Configure the root logger and dump source files into the log.

    Args:
        logpath: file the log is appended to when ``saving`` is true.
        filepath: path of a text file (typically the running script) whose
            name and full contents are logged first.
        package_files: optional iterable of further file paths to log the
            same way. ``None`` (the default) means no extra files.
        displaying: attach a ``StreamHandler`` when true.
        saving: attach a ``FileHandler`` on ``logpath`` when true.
        debug: use ``logging.DEBUG`` instead of ``logging.INFO``.

    Returns:
        The root logger.

    Note:
        Handlers are added on every call, so calling this repeatedly
        duplicates output.
    """
    # Avoid a shared mutable default argument.
    if package_files is None:
        package_files = ()

    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)

    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for package_path in package_files:
        logger.info(package_path)
        with open(package_path, "r") as package_f:
            logger.info(package_f.read())

    return logger
38 |
39 |
class AverageMeter(object):
    """Tracks the latest value and the count-weighted mean of all values seen."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
57 |
58 |
class RunningAverageMeter(object):
    """Tracks the latest value and an exponentially smoothed average."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget the last value (``None``) and zero the average."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Fold ``val`` into the average; the first value seeds it directly."""
        first_update = self.val is None
        if first_update:
            smoothed = val
        else:
            smoothed = self.avg * self.momentum + val * (1 - self.momentum)
        self.avg = smoothed
        self.val = val
76 |
77 |
def inf_generator(iterable):
    """Cycle over ``iterable`` forever, restarting it whenever it is exhausted.

    Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):

    A fresh iterator is requested from ``iterable`` for every pass. If a full
    pass yields nothing (the iterable is empty), the generator stops instead
    of spinning forever without producing a value.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # Restarting an empty iterable can never yield; stop here.
            return
88 |
89 |
def save_checkpoint(state, save, epoch, last_checkpoints=None, num_checkpoints=None):
    """Write ``state`` to ``<save>/checkpt-<epoch>.pth`` and optionally rotate old files.

    Args:
        state: object handed to ``torch.save``.
        save: target directory, created when missing.
        epoch: integer used in the zero-padded file name.
        last_checkpoints: optional list of epochs already saved; it is
            mutated in place (the new epoch is appended, the oldest popped).
        num_checkpoints: optional maximum number of files to keep. Rotation
            only happens when both this and ``last_checkpoints`` are given.
    """

    def checkpoint_path(ep):
        return os.path.join(save, 'checkpt-%04d.pth' % ep)

    if not os.path.exists(save):
        os.makedirs(save)
    torch.save(state, checkpoint_path(epoch))

    if last_checkpoints is None or num_checkpoints is None:
        return

    last_checkpoints.append(epoch)
    if len(last_checkpoints) > num_checkpoints:
        # Drop the oldest tracked checkpoint from disk.
        oldest = last_checkpoints.pop(0)
        os.remove(checkpoint_path(oldest))
101 |
102 |
def isnan(tensor):
    """Return the elementwise result of ``tensor != tensor``.

    NaN is the only floating-point value that differs from itself, so the
    result flags NaN entries.
    """
    differs_from_itself = tensor != tensor
    return differs_from_itself
105 |
106 |
def logsumexp(value, dim=None, keepdim=False):
    """Numerically stable ``value.exp().sum(dim, keepdim).log()``.

    The maximum is subtracted before exponentiating and added back afterwards.

    Args:
        value: input tensor.
        dim: dimension to reduce; ``None`` reduces over all elements.
        keepdim: whether the reduced dimension is kept (only used with ``dim``).
    """
    if dim is None:
        peak = torch.max(value)
        total = torch.sum(torch.exp(value - peak))
        if isinstance(total, Number):
            return peak + math.log(total)
        return peak + torch.log(total)

    peak, _ = torch.max(value, dim=dim, keepdim=True)
    shifted = value - peak
    if keepdim is False:
        # Match the shape of the reduced sum below.
        peak = peak.squeeze(dim)
    reduced = torch.sum(torch.exp(shifted), dim=dim, keepdim=keepdim)
    return peak + torch.log(reduced)
124 |
125 |
class ExponentialMovingAverage(object):
    """Maintains a "shadow" exponential moving average of a module's parameters."""

    def __init__(self, module, decay=0.999):
        """Initializes the model when .apply() is called the first time.
        This is to take into account data-dependent initialization that occurs in the first iteration."""
        self.module = module
        self.decay = decay
        self.shadow_params = {}
        param_count = 0
        for p in module.parameters():
            param_count += p.numel()
        self.nparams = param_count

    def init(self):
        """Snapshot the module's current parameters as the shadow values."""
        for name, param in self.module.named_parameters():
            self.shadow_params[name] = param.data.clone()

    def apply(self):
        """Seed the shadow copy on first use; afterwards move it toward the live parameters."""
        if not self.shadow_params:
            self.init()
            return
        with torch.no_grad():
            for name, param in self.module.named_parameters():
                shadow = self.shadow_params[name]
                # In-place update: shadow <- shadow - (1 - decay) * (shadow - param)
                shadow -= (1 - self.decay) * (shadow - param.data)

    def set(self, other_ema):
        """Re-snapshot the module, then overwrite the shadows with ``other_ema``'s."""
        self.init()
        with torch.no_grad():
            for name, source in other_ema.shadow_params.items():
                self.shadow_params[name].copy_(source)

    def replace_with_ema(self):
        """Copy the shadow values into the module's parameters."""
        for name, param in self.module.named_parameters():
            param.data.copy_(self.shadow_params[name])

    def swap(self):
        """Exchange the module's parameters with the shadow values."""
        for name, param in self.module.named_parameters():
            shadow = self.shadow_params[name]
            previous_shadow = shadow.clone()
            shadow.copy_(param.data)
            param.data.copy_(previous_shadow)

    def __repr__(self):
        template = '{}(decay={}, module={}, nparams={})'
        return template.format(
            self.__class__.__name__, self.decay, self.module.__class__.__name__, self.nparams
        )
170 |
--------------------------------------------------------------------------------
/lib/visualize_flow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use("Agg")
4 | import matplotlib.pyplot as plt
5 | import torch
6 |
# Plotting window: the grids, histogram ranges and axis limits used by the
# helpers in this module span [LOW, HIGH] on both axes.
LOW = -4
HIGH = 4
9 |
10 |
def plt_potential_func(potential, ax, npts=100, title="$p(x)$"):
    """Plot ``exp(-potential)`` on an ``npts`` x ``npts`` grid over [LOW, HIGH]^2.

    Args:
        potential: computes U(z_k) given z_k
        ax: matplotlib axes to draw on.
        npts: number of grid points per axis.
        title: axes title.
    """
    xside = np.linspace(LOW, HIGH, npts)
    yside = np.linspace(LOW, HIGH, npts)
    xx, yy = np.meshgrid(xside, yside)
    z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])

    z = torch.Tensor(z)
    u = potential(z).cpu().numpy()
    p = np.exp(-u).reshape(npts, npts)

    # Draw on the axes that was passed in rather than on pyplot's implicit
    # "current" axes, so the plot lands on ``ax`` even when it is not current.
    ax.pcolormesh(xx, yy, p)
    ax.invert_yaxis()
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.set_title(title)
30 |
31 |
def plt_flow(prior_logdensity, transform, ax, npts=100, title="$q(x)$", device="cpu"):
    """Push a regular grid through ``transform`` and plot the resulting density.

    Args:
        prior_logdensity: callable giving per-dimension log-density of the
            grid points; it is summed over dim 1 here.
        transform: computes z_k and log(q_k) given z_0
        ax: matplotlib axes to draw on.
        npts: number of grid points per axis.
        title: axes title.
        device: torch device the grid is moved to.
    """
    side = np.linspace(LOW, HIGH, npts)
    xx, yy = np.meshgrid(side, side)
    z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])

    z = torch.tensor(z, requires_grad=True).type(torch.float32).to(device)
    logqz = prior_logdensity(z)
    logqz = torch.sum(logqz, dim=1)[:, None]
    z, logqz = transform(z, logqz)
    logqz = torch.sum(logqz, dim=1)[:, None]

    # detach(): the grid was created with requires_grad=True, and .numpy()
    # refuses tensors that are still attached to the autograd graph.
    xx = z[:, 0].detach().cpu().numpy().reshape(npts, npts)
    yy = z[:, 1].detach().cpu().numpy().reshape(npts, npts)
    qz = np.exp(logqz.detach().cpu().numpy()).reshape(npts, npts)

    # Draw on the supplied axes instead of pyplot's implicit current axes.
    ax.pcolormesh(xx, yy, qz)
    ax.set_xlim(LOW, HIGH)
    ax.set_ylim(LOW, HIGH)
    # plt.get_cmap() returns the default colormap; matplotlib.cm.get_cmap was
    # removed in Matplotlib 3.9.
    cmap = plt.get_cmap()
    # Fill regions not covered by the warped mesh with the colormap's lowest color.
    ax.set_facecolor(cmap(0.))
    ax.invert_yaxis()
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.set_title(title)
60 |
61 |
def plt_flow_density(prior_logdensity, inverse_transform, ax, npts=100, memory=100, title="$q(x)$", device="cpu"):
    """Evaluate the model density on a grid over [LOW, HIGH]^2 and show it as an image.

    Grid points are pushed through ``inverse_transform`` in chunks of
    ``memory**2`` rows. The plotted value is
    ``exp(sum(prior_logdensity(z)) - delta_logp)``.
    """
    axis_pts = np.linspace(LOW, HIGH, npts)
    grid_x, grid_y = np.meshgrid(axis_pts, axis_pts)
    grid = np.hstack([grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)])

    x = torch.from_numpy(grid).type(torch.float32).to(device)
    zeros = torch.zeros(x.shape[0], 1).to(x)

    chunk_rows = int(memory**2)
    row_ids = torch.arange(0, x.shape[0]).to(torch.int64)
    z_chunks = []
    dlogp_chunks = []
    for rows in torch.split(row_ids, chunk_rows):
        z_part, dlogp_part = inverse_transform(x[rows], zeros[rows])
        z_chunks.append(z_part)
        dlogp_chunks.append(dlogp_part)
    z = torch.cat(z_chunks, 0)
    delta_logp = torch.cat(dlogp_chunks, 0)

    # logp(z)
    logpz = prior_logdensity(z).view(z.shape[0], -1).sum(1, keepdim=True)
    logpx = logpz - delta_logp

    density = np.exp(logpx.cpu().numpy()).reshape(npts, npts)

    ax.imshow(density, cmap='inferno')
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_ticks([])
    ax.set_title(title)
88 |
89 |
def plt_flow_samples(prior_sample, transform, ax, npts=100, memory=100, title="$x ~ q(x)$", device="cpu"):
    """Draw ``npts * npts`` prior samples, map them through ``transform`` and plot a 2-D histogram.

    Samples are transformed in chunks of ``memory**2`` rows.
    """
    z = prior_sample(npts * npts, 2).type(torch.float32).to(device)
    row_ids = torch.arange(0, z.shape[0]).to(torch.int64)
    mapped = [transform(z[rows]) for rows in torch.split(row_ids, int(memory**2))]
    zk = torch.cat(mapped, 0).cpu().numpy()

    window = [[LOW, HIGH], [LOW, HIGH]]
    ax.hist2d(zk[:, 0], zk[:, 1], range=window, bins=npts, cmap='inferno')
    ax.invert_yaxis()
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_ticks([])
    ax.set_title(title)
102 |
103 |
def plt_samples(samples, ax, npts=100, title="$x ~ p(x)$"):
    """Plot a 2-D histogram of columns 0 and 1 of ``samples`` over [LOW, HIGH]^2."""
    window = [[LOW, HIGH], [LOW, HIGH]]
    xs, ys = samples[:, 0], samples[:, 1]
    ax.hist2d(xs, ys, range=window, bins=npts, cmap='inferno')
    ax.invert_yaxis()
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_ticks([])
    ax.set_title(title)
110 |
111 |
def visualize_transform(
    potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100,
    memory=100, device="cpu"
):
    """Produces visualization for the model density and samples from the model.

    Clears the current figure and fills a 1x3 row of subplots:
      1. the target, as a sample histogram (``samples=True``) or a potential;
      2. the model density, via ``inverse_transform`` when given, otherwise
         via the forward ``transform``;
      3. model samples, only when ``transform`` is given.
    """
    plt.clf()

    # Panel 1: target distribution.
    ax = plt.subplot(1, 3, 1, aspect="equal")
    draw_target = plt_samples if samples else plt_potential_func
    draw_target(potential_or_samples, ax, npts=npts)

    # Panel 2: model density.
    ax = plt.subplot(1, 3, 2, aspect="equal")
    if inverse_transform is not None:
        plt_flow_density(prior_density, inverse_transform, ax, npts=npts, memory=memory, device=device)
    else:
        plt_flow(prior_density, transform, ax, npts=npts, device=device)

    # Panel 3: model samples (left empty without a forward transform).
    ax = plt.subplot(1, 3, 3, aspect="equal")
    if transform is None:
        return
    plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device)
133 |
--------------------------------------------------------------------------------
/preprocessing/convert_to_pth.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import re
3 | import numpy as np
4 | import torch
5 |
# Usage: python convert_to_pth.py <array.npy>
# Loads the array, moves axis 3 to position 1 (presumably NHWC -> NCHW; TODO
# confirm against the producer of the .npy file) and saves it with torch.save.
src_path = sys.argv[1]
img = torch.tensor(np.load(src_path))
img = img.permute(0, 3, 1, 2)

# Derive the output name without a regex. The former pattern '.npy$' had an
# unescaped dot (it also matched e.g. "_npy"), and when nothing matched the
# output path equalled the input path, overwriting the source file.
npy_suffix = '.npy'
if src_path.endswith(npy_suffix):
    dst_path = src_path[:-len(npy_suffix)] + '.pth'
else:
    dst_path = src_path + '.pth'
torch.save(img, dst_path)
9 |
--------------------------------------------------------------------------------
/preprocessing/create_imagenet_benchmark_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Run the following commands in ~ before running this file
3 | wget http://image-net.org/small/train_64x64.tar
4 | wget http://image-net.org/small/valid_64x64.tar
5 | tar -xvf train_64x64.tar
6 | tar -xvf valid_64x64.tar
7 | wget http://image-net.org/small/train_32x32.tar
8 | wget http://image-net.org/small/valid_32x32.tar
9 | tar -xvf train_32x32.tar
10 | tar -xvf valid_32x32.tar
11 | """
12 |
13 | import numpy as np
14 | import scipy.ndimage
15 | import os
16 | from os import listdir
17 | from os.path import isfile, join
18 | from tqdm import tqdm
19 |
20 |
def convert_path_to_npy(*, path='train_64x64', outfile='train_64x64.npy'):
    """Read every regular file in ``path`` as an image and save them as one uint8 array.

    Args:
        path: directory containing the image files.
        outfile: destination handed to ``np.save``.

    Raises:
        AssertionError: if ``path`` is not a string or does not exist, or if
            the stacked array is not ``(N, H, W, 3)`` with H, W taken from
            the last image read.
        ValueError: if ``path`` contains no regular files.
    """
    assert isinstance(path, str), "Expected a string input for the path"
    assert os.path.exists(path), "Input path doesn't exist"
    files = [f for f in listdir(path) if isfile(join(path, f))]
    print('Number of valid images is:', len(files))
    if not files:
        # Without this guard an empty directory only failed later with a
        # NameError on the never-assigned loop variable ``img``.
        raise ValueError('No image files found in {!r}'.format(path))

    imgs = []
    for fname in tqdm(files):
        # NOTE(review): scipy.ndimage.imread was removed in SciPy 1.2; this
        # call needs an older SciPy or a replacement reader.
        img = scipy.ndimage.imread(join(path, fname))
        img = img.astype('uint8')
        assert np.max(img) <= 255
        assert np.min(img) >= 0
        assert img.dtype == 'uint8'
        assert isinstance(img, np.ndarray)
        imgs.append(img)

    # Resolution of the last image read; the shape assertion below checks
    # that the stacked array agrees with it.
    resolution_x, resolution_y = img.shape[0], img.shape[1]
    imgs = np.asarray(imgs).astype('uint8')
    assert imgs.shape[1:] == (resolution_x, resolution_y, 3)
    assert np.max(imgs) <= 255
    assert np.min(imgs) >= 0
    print('Total number of images is:', imgs.shape[0])
    print('All assertions done, dumping into npy file')
    np.save(outfile, imgs)
43 |
44 |
if __name__ == '__main__':
    # Convert each extracted split directory into a same-named .npy file,
    # in the order: 64x64 train/valid, then 32x32 train/valid.
    for split_dir in ('train_64x64', 'valid_64x64', 'train_32x32', 'valid_32x32'):
        convert_path_to_npy(path=split_dir, outfile=split_dir + '.npy')
50 |
--------------------------------------------------------------------------------
/preprocessing/extract_celeba_from_tfrecords.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import torch
4 |
def _read_split(shard_pattern, num_shards):
    """Return the list of 256x256x3 uint8 images stored in one split's TFRecord shards."""
    imgs = []
    for i in range(num_shards):
        tfr = shard_pattern.format(i)
        print(tfr, flush=True)

        for string_record in tf.python_io.tf_record_iterator(tfr):
            example = tf.train.Example()
            example.ParseFromString(string_record)

            image_bytes = example.features.feature['data'].bytes_list.value[0]

            # Decode the raw bytes with NumPy. The previous version built new
            # tf.decode_raw / tf.reshape ops for every image and evaluated them
            # in a session, so the TF graph kept growing with the image count.
            img = np.frombuffer(image_bytes, dtype=np.uint8).reshape(256, 256, 3)
            imgs.append(img)
    return imgs


def _save_split(imgs, out_path):
    """Stack images, move axis 3 to position 1 and write the tensor with torch.save."""
    stacked = np.stack(imgs)
    torch.save(torch.tensor(stacked).permute(0, 3, 1, 2), out_path)


# No tf.InteractiveSession is created any more: nothing below evaluates a
# graph tensor.
print('Reading from training set...', flush=True)
train_imgs = _read_split('data/celebahq/celeba-tfr/train/train-r08-s-{:04d}-of-0120.tfrecords', 120)
_save_split(train_imgs, 'data/celebahq/celeba256_train.pth')

validation_imgs = _read_split(
    'data/celebahq/celeba-tfr/validation/validation-r08-s-{:04d}-of-0040.tfrecords', 40
)
_save_split(validation_imgs, 'data/celebahq/celeba256_validation.pth')
54 |
--------------------------------------------------------------------------------
/qualitative_samples.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | from tqdm import tqdm
4 |
5 | import torch
6 | import torchvision.transforms as transforms
7 | from torchvision.utils import save_image
8 | import torchvision.datasets as vdsets
9 |
10 | from lib.iresnet import ACT_FNS, ResidualFlow
11 | import lib.datasets as datasets
12 | import lib.utils as utils
13 | import lib.layers as layers
14 | import lib.layers.base as base_layers
15 |
# Arguments
# NOTE(review): boolean flags below use ``type=eval`` so that the strings
# "True"/"False" become booleans; eval executes whatever expression is given
# on the command line before ``choices`` is checked.
parser = argparse.ArgumentParser()
# NOTE(review): 'celeba' is accepted here but the dataset selection further
# down has no branch for it, while that selection does have an 'svhn' branch
# that this choices list never lets through.
parser.add_argument(
    '--data', type=str, default='cifar10', choices=[
        'mnist',
        'cifar10',
        'celeba',
        'celebahq',
        'celeba_5bit',
        'imagenet32',
        'imagenet64',
    ]
)
parser.add_argument('--dataroot', type=str, default='data')
parser.add_argument('--imagesize', type=int, default=32)
parser.add_argument('--nbits', type=int, default=8)  # Only used for celebahq.

# Sampling parameters.
# --real additionally builds the data loaders and dumps a grid of real images;
# --nrow * --ncol becomes the batch size; --temp scales the sampled latents.
parser.add_argument('--real', type=eval, choices=[True, False], default=False)
parser.add_argument('--nrow', type=int, default=10)
parser.add_argument('--ncol', type=int, default=10)
parser.add_argument('--temp', type=float, default=1.0)
parser.add_argument('--nbatches', type=int, default=5)
parser.add_argument('--save-each', type=eval, choices=[True, False], default=False)

# Everything from here to --first-resblock is forwarded to the ResidualFlow
# constructor (or used to derive its input size).
parser.add_argument('--block', type=str, choices=['resblock', 'coupling'], default='resblock')

parser.add_argument('--coeff', type=float, default=0.98)
parser.add_argument('--vnorms', type=str, default='2222')
parser.add_argument('--n-lipschitz-iters', type=int, default=None)
parser.add_argument('--sn-tol', type=float, default=1e-3)
parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False)

parser.add_argument('--n-power-series', type=int, default=None)
parser.add_argument('--factor-out', type=eval, choices=[True, False], default=False)
parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric')
parser.add_argument('--n-samples', type=int, default=1)
parser.add_argument('--n-exact-terms', type=int, default=2)
parser.add_argument('--var-reduc-lr', type=float, default=0)
parser.add_argument('--neumann-grad', type=eval, choices=[True, False], default=True)
parser.add_argument('--mem-eff', type=eval, choices=[True, False], default=True)

parser.add_argument('--act', type=str, choices=ACT_FNS.keys(), default='swish')
parser.add_argument('--idim', type=int, default=512)
parser.add_argument('--nblocks', type=str, default='16-16-16')
parser.add_argument('--squeeze-first', type=eval, default=False, choices=[True, False])
parser.add_argument('--actnorm', type=eval, default=True, choices=[True, False])
parser.add_argument('--fc-actnorm', type=eval, default=False, choices=[True, False])
parser.add_argument('--batchnorm', type=eval, default=False, choices=[True, False])
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--fc', type=eval, default=False, choices=[True, False])
parser.add_argument('--kernels', type=str, default='3-1-3')
parser.add_argument('--quadratic', type=eval, choices=[True, False], default=False)
parser.add_argument('--fc-end', type=eval, choices=[True, False], default=True)
parser.add_argument('--fc-idim', type=int, default=128)
parser.add_argument('--preact', type=eval, choices=[True, False], default=True)
parser.add_argument('--padding', type=int, default=0)
parser.add_argument('--first-resblock', type=eval, choices=[True, False], default=True)

# Only 'density' is accepted, so the 'classification'/'hybrid' checks later in
# this script can never be true.
parser.add_argument('--task', type=str, choices=['density'], default='density')
parser.add_argument('--rcrop-pad-mode', type=str, choices=['constant', 'reflect'], default='reflect')
parser.add_argument('--padding-dist', type=str, choices=['uniform', 'gaussian'], default='uniform')

# Path of the checkpoint to sample from (mandatory).
parser.add_argument('--resume', type=str, required=True)
parser.add_argument('--nworkers', type=int, default=4)
args = parser.parse_args()
82 |
# Grid layout of the saved image sheets; one batch fills exactly one sheet.
H, W = args.nrow, args.ncol

args.batchsize = args.val_batchsize = W * H

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Report the hardware in use.
if device.type != 'cuda':
    print('WARNING: Using device {}'.format(device))
else:
    print('Found {} CUDA devices.'.format(torch.cuda.device_count()))
    for gpu_index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(gpu_index)
        print('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3)))
98 |
99 |
def geometric_logprob(ns, p):
    """Return ``log(1 - p) * (ns - 1) + log(p)``, with 1e-10 added inside each log."""
    log_failure = torch.log(1 - p + 1e-10)
    log_success = torch.log(p + 1e-10)
    return log_failure * (ns - 1) + log_success
102 |
103 |
def standard_normal_sample(size):
    """Draw a tensor of shape ``size`` with ``torch.randn``."""
    sample = torch.randn(size)
    return sample
106 |
107 |
def standard_normal_logprob(z):
    """Elementwise value of ``-0.5 * log(2 * pi) - z**2 / 2``."""
    half_square = z.pow(2) / 2
    log_normalizer = -0.5 * math.log(2 * math.pi)
    return log_normalizer - half_square
111 |
112 |
def count_parameters(model):
    """Count the scalar entries of all parameters of ``model`` that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
115 |
116 |
def add_noise(x, nvals=255):
    """
    [0, 1] -> [0, nvals] -> add noise -> [0, 1]

    Scales ``x`` by ``nvals``, adds uniform noise drawn from [0, 1) and
    divides by ``nvals + 1``.
    """
    jitter = x.new_empty(x.shape).uniform_()
    scaled = x * nvals + jitter
    return scaled / (nvals + 1)
125 |
126 |
def update_lr(optimizer, itr):
    """Set every param group's lr to ``args.lr`` scaled by a linear warm-up factor.

    The factor is ``min((itr + 1) / max(args.warmup_iters, 1), 1.0)``.

    NOTE(review): this script's argument parser declares neither
    ``--warmup-iters`` nor ``--lr``, so calling this here would fail on the
    missing attributes.
    """
    warmup_span = max(args.warmup_iters, 1)
    iter_frac = min(float(itr + 1) / warmup_span, 1.0)
    lr = args.lr * iter_frac
    for group in optimizer.param_groups:
        group["lr"] = lr
132 |
133 |
def add_padding(x):
    """Prepend ``args.padding`` uniform-noise channels to ``x`` along dim 1.

    Returns:
        ``(padded_x, logpu)`` where ``logpu`` holds one zero per batch
        element. With no padding configured, ``x`` is returned unchanged.
    """
    if args.padding <= 0:
        return x, torch.zeros(x.shape[0]).to(x)

    noise = x.new_empty(x.shape[0], args.padding, x.shape[2], x.shape[3]).uniform_()
    logpu = torch.zeros_like(noise).sum([1, 2, 3])
    return torch.cat([noise, x], dim=1), logpu
141 |
142 |
def remove_padding(x):
    """Drop the first ``args.padding`` channels (dim 1) of ``x``, if any are configured."""
    if args.padding <= 0:
        return x
    return x[:, args.padding:, :, :]
148 |
149 |
def reduce_bits(x):
    """Quantize ``x`` to ``args.nbits`` bits when fewer than 8 are requested.

    ``x`` is scaled by 255, floor-divided into ``2**nbits`` levels and
    rescaled by ``2**nbits``. With 8 or more bits it is returned unchanged.
    """
    if args.nbits >= 8:
        return x
    levels = torch.floor(x * 255 / 2**(8 - args.nbits))
    return levels / 2**args.nbits
156 |
157 |
def update_lipschitz(model):
    """Call ``compute_weight(update=True)`` on every spectral-/induced-norm layer of ``model``."""
    spectral_types = (base_layers.SpectralNormConv2d, base_layers.SpectralNormLinear)
    induced_types = (base_layers.InducedNormConv2d, base_layers.InducedNormLinear)
    for layer in model.modules():
        # Two separate checks, mirroring the per-family handling.
        if isinstance(layer, spectral_types):
            layer.compute_weight(update=True)
        if isinstance(layer, induced_types):
            layer.compute_weight(update=True)
164 |
165 |
print('Loading dataset {}'.format(args.data), flush=True)
# Dataset and hyperparameters
# Each branch sets ``im_dim`` (image channels) and ``init_layer`` (first layer
# handed to the model); some also set ``n_classes``. Data loaders are only
# built when --real is given, since sampling itself needs no data.
# NOTE(review): 'celeba' is an accepted --data value but has no branch here,
# which leaves ``im_dim``/``init_layer`` undefined for it.
if args.data == 'cifar10':
    im_dim = 3
    n_classes = 10

    if args.task in ['classification', 'hybrid']:

        if args.real:

            # Classification-specific preprocessing.
            transform_train = transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                add_noise,
            ])

            transform_test = transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.ToTensor(),
                add_noise,
            ])

        # Remove the logit transform.
        init_layer = layers.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    else:
        if args.real:
            transform_train = transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                add_noise,
            ])
            transform_test = transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.ToTensor(),
                add_noise,
            ])
        init_layer = layers.LogitTransform(0.05)
    if args.real:
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.dataroot, train=True, transform=transform_train),
            batch_size=args.batchsize,
            shuffle=True,
            num_workers=args.nworkers,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.dataroot, train=False, transform=transform_test),
            batch_size=args.val_batchsize,
            shuffle=False,
            num_workers=args.nworkers,
        )
elif args.data == 'mnist':
    im_dim = 1
    init_layer = layers.LogitTransform(1e-6)
    n_classes = 10

    if args.real:
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                args.dataroot, train=True, transform=transforms.Compose([
                    transforms.Resize(args.imagesize),
                    transforms.ToTensor(),
                    add_noise,
                ])
            ),
            batch_size=args.batchsize,
            shuffle=True,
            num_workers=args.nworkers,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                args.dataroot, train=False, transform=transforms.Compose([
                    transforms.Resize(args.imagesize),
                    transforms.ToTensor(),
                    add_noise,
                ])
            ),
            batch_size=args.val_batchsize,
            shuffle=False,
            num_workers=args.nworkers,
        )
# NOTE(review): unreachable as written -- 'svhn' is not among the --data choices.
elif args.data == 'svhn':
    im_dim = 3
    init_layer = layers.LogitTransform(0.05)
    n_classes = 10

    if args.real:
        train_loader = torch.utils.data.DataLoader(
            vdsets.SVHN(
                args.dataroot, split='train', download=True, transform=transforms.Compose([
                    transforms.Resize(args.imagesize),
                    transforms.RandomCrop(32, padding=4, padding_mode=args.rcrop_pad_mode),
                    transforms.ToTensor(),
                    add_noise,
                ])
            ),
            batch_size=args.batchsize,
            shuffle=True,
            num_workers=args.nworkers,
        )
        test_loader = torch.utils.data.DataLoader(
            vdsets.SVHN(
                args.dataroot, split='test', download=True, transform=transforms.Compose([
                    transforms.Resize(args.imagesize),
                    transforms.ToTensor(),
                    add_noise,
                ])
            ),
            batch_size=args.val_batchsize,
            shuffle=False,
            num_workers=args.nworkers,
        )
elif args.data == 'celebahq':
    im_dim = 3
    init_layer = layers.LogitTransform(0.05)

    if args.real:
        # Bit depth is reduced per --nbits before noise with a matching
        # number of levels is added.
        train_loader = torch.utils.data.DataLoader(
            datasets.CelebAHQ(
                train=True, transform=transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    reduce_bits,
                    lambda x: add_noise(x, nvals=2**args.nbits),
                ])
            ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.CelebAHQ(
                train=False, transform=transforms.Compose([
                    reduce_bits,
                    lambda x: add_noise(x, nvals=2**args.nbits),
                ])
            ), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
        )
elif args.data == 'celeba_5bit':
    im_dim = 3
    init_layer = layers.LogitTransform(0.05)
    # This dataset is only handled at 64x64; any other --imagesize is overridden.
    if args.imagesize != 64:
        print('Changing image size to 64.')
        args.imagesize = 64

    if args.real:
        train_loader = torch.utils.data.DataLoader(
            datasets.CelebA5bit(
                train=True, transform=transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    lambda x: add_noise(x, nvals=32),
                ])
            ), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.CelebA5bit(train=False, transform=transforms.Compose([
                lambda x: add_noise(x, nvals=32),
            ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
        )
elif args.data == 'imagenet32':
    im_dim = 3
    init_layer = layers.LogitTransform(0.05)
    # Image size is forced to 32 for this dataset.
    if args.imagesize != 32:
        print('Changing image size to 32.')
        args.imagesize = 32

    if args.real:
        train_loader = torch.utils.data.DataLoader(
            datasets.Imagenet32(train=True, transform=transforms.Compose([
                add_noise,
            ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.Imagenet32(train=False, transform=transforms.Compose([
                add_noise,
            ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
        )
elif args.data == 'imagenet64':
    im_dim = 3
    init_layer = layers.LogitTransform(0.05)
    # Image size is forced to 64 for this dataset.
    if args.imagesize != 64:
        print('Changing image size to 64.')
        args.imagesize = 64

    if args.real:
        train_loader = torch.utils.data.DataLoader(
            datasets.Imagenet64(train=True, transform=transforms.Compose([
                add_noise,
            ])), batch_size=args.batchsize, shuffle=True, num_workers=args.nworkers
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.Imagenet64(train=False, transform=transforms.Compose([
                add_noise,
            ])), batch_size=args.val_batchsize, shuffle=False, num_workers=args.nworkers
        )
364 |
# Classification-style tasks need a dataset branch that defined ``n_classes``;
# the bare name lookup raises NameError when it did not. Every other task
# uses a single dummy class.
if args.task in ['classification', 'hybrid']:
    try:
        n_classes
    except NameError:
        raise ValueError('Cannot perform classification with {}'.format(args.data))
else:
    n_classes = 1

print('Dataset loaded.', flush=True)
print('Creating model.', flush=True)

# (batch, channels incl. padding channels, height, width)
input_size = (args.batchsize, im_dim + args.padding, args.imagesize, args.imagesize)

# An initial 2x squeeze quadruples the channels and halves both spatial sizes.
# The layer is kept so that samples can be un-squeezed after inversion.
if args.squeeze_first:
    input_size = (input_size[0], input_size[1] * 4, input_size[2] // 2, input_size[3] // 2)
    squeeze_layer = layers.SqueezeLayer(2)

# Model
# Architecture options come straight from the command line.
# NOTE(review): presumably these must match the configuration the checkpoint
# was trained with, since the state dict is loaded with strict=True -- confirm.
model = ResidualFlow(
    input_size,
    n_blocks=list(map(int, args.nblocks.split('-'))),
    intermediate_dim=args.idim,
    factor_out=args.factor_out,
    quadratic=args.quadratic,
    init_layer=init_layer,
    actnorm=args.actnorm,
    fc_actnorm=args.fc_actnorm,
    batchnorm=args.batchnorm,
    dropout=args.dropout,
    fc=args.fc,
    coeff=args.coeff,
    vnorms=args.vnorms,
    n_lipschitz_iters=args.n_lipschitz_iters,
    sn_atol=args.sn_tol,
    sn_rtol=args.sn_tol,
    n_power_series=args.n_power_series,
    n_dist=args.n_dist,
    n_samples=args.n_samples,
    kernels=args.kernels,
    activation_fn=args.act,
    fc_end=args.fc_end,
    fc_idim=args.fc_idim,
    n_exact_terms=args.n_exact_terms,
    preact=args.preact,
    neumann_grad=args.neumann_grad,
    grad_in_forward=args.mem_eff,
    first_resblock=args.first_resblock,
    learn_p=args.learn_p,
    block_type=args.block,
)
415 |
model.to(device)

print('Initializing model.', flush=True)

# One dummy forward pass before loading weights.
# NOTE(review): presumably this triggers data-dependent initialization so that
# all parameters/buffers exist before load_state_dict -- confirm.
with torch.no_grad():
    x = torch.rand(1, *input_size[1:]).to(device)
    model(x)
print('Restoring from checkpoint.', flush=True)
# map_location lets a checkpoint saved on a GPU be restored on whatever device
# was selected above (including CPU-only machines).
# NOTE(review): checkpt['ema'] is used as an object with ``shadow_params``, so
# the checkpoint holds pickled Python objects; recent PyTorch releases may
# require weights_only=False to load it -- confirm against the installed version.
checkpt = torch.load(args.resume, map_location=device)
model.load_state_dict(checkpt['state_dict'], strict=True)

# Load the averaged weights into the model; swap() exchanges the model's
# parameters with the EMA shadow copy.
ema = utils.ExponentialMovingAverage(model)
ema.set(checkpt['ema'])
ema.swap()

print(model, flush=True)

model.eval()
print('Updating lipschitz.', flush=True)
update_lipschitz(model)
437 |
438 |
def visualize(model):
    """Sample ``args.nbatches`` batches from ``model`` and save them as PNG files.

    Latents are drawn from a standard normal scaled by ``args.temp`` and
    inverted through the model. With ``--save-each`` every image is written
    individually into ``<data>_imgs_t<temp>/``; otherwise one grid per batch
    is written into ``imgs/``.
    """
    utils.makedirs('{}_imgs_t{}'.format(args.data, args.temp))
    if not args.save_each:
        # The grid images go to 'imgs/', which nothing else creates.
        utils.makedirs('imgs')

    latent_dim = (im_dim + args.padding) * args.imagesize * args.imagesize

    with torch.no_grad():

        for i in tqdm(range(args.nbatches)):
            # random samples
            rand_z = torch.randn(args.batchsize, latent_dim).to(device)
            rand_z = rand_z * args.temp
            fake_imgs = model(rand_z, inverse=True).view(-1, *input_size[1:])
            if args.squeeze_first:
                fake_imgs = squeeze_layer.inverse(fake_imgs)
            fake_imgs = remove_padding(fake_imgs)
            fake_imgs = fake_imgs.view(-1, im_dim, args.imagesize, args.imagesize)
            fake_imgs = fake_imgs.cpu()

            # NOTE(review): newer torchvision releases renamed save_image's
            # ``range`` keyword to ``value_range`` -- confirm against the
            # installed version.
            if args.save_each:
                for j in range(fake_imgs.shape[0]):
                    save_image(
                        fake_imgs[j], '{}_imgs_t{}/{}.png'.format(args.data, args.temp, args.batchsize * i + j), nrow=1,
                        padding=0, range=(0, 1), pad_value=0
                    )
            else:
                save_image(
                    fake_imgs, 'imgs/{}_t{}_samples{}.png'.format(args.data, args.temp, i), nrow=W, padding=2,
                    range=(0, 1), pad_value=1
                )
465 |
466 |
# Optionally dump one batch of real test images for comparison with samples.
# The batch is fetched once; previously a second fresh iterator was created
# inside the `if` just to fetch a batch again.
real_imgs = next(iter(test_loader))[0] if args.real else None
if args.real:
    # Make sure the output directory for the grid exists.
    utils.makedirs('imgs')
    # `range=(0, 1)` is omitted: it has no effect without `normalize=True`
    # and newer torchvision releases reject the keyword.
    save_image(
        real_imgs.cpu().float(), 'imgs/{}_real.png'.format(args.data), nrow=W, padding=2, pad_value=1
    )

visualize(model)
475 |
--------------------------------------------------------------------------------
/run_cifar10.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_img.py --data cifar10 --actnorm True \
2 | --nblocks '2-2-2' --idim '512' --act 'swish' --kernels '3-1-3' --vnorms '2222' --fc-end False --preact True \
3 | --save 'experiments/cifar10(blocks_2*3(512,k313)_swish_nofc_preact_10term' --coeff 0.9 --n-exact-terms 10
4 |
--------------------------------------------------------------------------------
/run_classification.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_classification.py --data 'cifar100' --model-dir 'experiments/classify_cifar100_Resnet18_c0.9' \
2 | --weight-decay 0 --epochs 150 --log-interval 20 --batch-size 128 --test-batch-size 128 --lr 0.001 --coeff 0.9
--------------------------------------------------------------------------------
/run_tabular.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_tabular.py --nblocks 20 --vnorms '222222' --dims '128-128-128-128' \
2 | --save 'experiments/tabular_(power_block20,128*4,c99,sin)_bf' --act 'sin' --data 'power' --batchsize 1000 --coeff 0.99 --nepochs 10000 --epsf 1e-5
3 |
--------------------------------------------------------------------------------
/run_toy.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_toy.py --nblocks 6 --vnorms '2222' --dims '128-128' \
2 | --arch 'implicit' --brute-force True --save 'experiments/res_toy(block6,128*2,c99,sin,5000)' --act 'sin' --data 'checkerboard' --batch_size 5000 --coeff 0.99 --n-lipschitz-iters 20
--------------------------------------------------------------------------------
/train_classification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | import lib.utils as utils
7 |
8 | import os
9 | import argparse
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torchvision
14 | import torch.optim as optim
15 | from torchvision import datasets, transforms
16 | import time
17 | import lib.layers.base as base_layers
18 | import lib.layers as layers
19 |
# Command-line interface of the classification script.
# Every numeric option declares an explicit `type` so that values given on the
# command line are converted; previously --epsilon, --num-steps, --step-size
# and --beta had none and were therefore left as strings when set by the user.
parser = argparse.ArgumentParser(description='PyTorch CIFAR TRADES Adversarial Training')
parser.add_argument(
    '--data', type=str, default='cifar10', choices=[
        'cifar10',
        'cifar100',
        'mnist',
    ]
)
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--epochs', type=int, default=76, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=2e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--epsilon', type=float, default=0.031,
                    help='perturbation')
parser.add_argument('--num-steps', type=int, default=10,
                    help='perturb number of steps')
parser.add_argument('--step-size', type=float, default=0.007,
                    help='perturb step size')
parser.add_argument('--beta', type=float, default=6.0,
                    help='regularization, i.e., 1/lambda in TRADES')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-dir', default='./experiments/model-cifar-Resnet18',
                    help='directory of model for saving checkpoint')
parser.add_argument('--save-freq', '-s', default=50, type=int, metavar='N',
                    help='save frequency')
parser.add_argument('--coeff', type=float, default=0.99)
59 |
# Parse the command line once at import time; `args` is read as a module-level
# global by the model classes and training helpers below.
args = parser.parse_args()

# settings
model_dir = args.model_dir
utils.makedirs(model_dir)
logger = utils.get_logger(logpath=os.path.join(model_dir, 'logs'), filepath=os.path.abspath(__file__))
logger.info(args)

# CUDA is used only when it is available and --no-cuda was not given.
use_cuda = not args.no_cuda and torch.cuda.is_available()
if use_cuda:
    torch.backends.cudnn.benchmark = True
# Seed NumPy and PyTorch from --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader keyword arguments; only populated when running on CUDA.
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}


# Maps an activation name to the class that constructs it.
ACTIVATION_FNS = {
    'identity': base_layers.Identity,
    'relu': torch.nn.ReLU,
    'tanh': torch.nn.Tanh,
    'elu': torch.nn.ELU,
    'selu': torch.nn.SELU,
    'fullsort': base_layers.FullSort,
    'maxmin': base_layers.MaxMin,
    'swish': base_layers.Swish,
    'lcube': base_layers.LipschitzCube,
    'sin': base_layers.Sin,
    'zero': base_layers.Zero,
}
91 |
class BasicBlock(nn.Module):
    """Two stacked residual branches followed by an optional projection.

    Each branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> ReLU, mapping
    ``in_planes -> hidden -> in_planes``. The projection (1x1 conv with
    ``stride``, BN, ReLU) is only created when ``stride != 1`` or the channel
    count changes; otherwise ``downsample`` is an empty ``nn.Sequential``.
    """
    expansion = 1

    def __init__(self, in_planes, hidden, planes, stride=1):
        super(BasicBlock, self).__init__()
        relu_cls = ACTIVATION_FNS['relu']

        def make_branch():
            # conv/BN/ReLU twice; spatial size and channel count are preserved.
            return nn.Sequential(
                nn.Conv2d(in_planes, hidden, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(hidden),
                relu_cls(),
                nn.Conv2d(hidden, in_planes, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(in_planes),
                relu_cls(),
            )

        self.block1 = make_branch()
        self.block2 = make_branch()

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            self.downsample = nn.Sequential()
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
                relu_cls(),
            )

    def forward(self, x):
        """Apply both residual branches, then the projection."""
        hidden_state = F.relu(x + self.block1(x))
        hidden_state = hidden_state + self.block2(hidden_state)
        return self.downsample(hidden_state)
133 |
134 |
class BasicImplicitBlock(nn.Module):
    """``layers.imBlock`` built from two Lipschitz-constrained conv branches.

    Each branch is ``get_conv2d(in_planes -> hidden)`` -> ReLU ->
    ``get_conv2d(hidden -> in_planes)`` -> ReLU, with the Lipschitz
    coefficient read from the global ``args.coeff``. As in ``BasicBlock``, a
    1x1-conv/BN/ReLU projection follows when the stride or channel count
    changes. The first ``forward`` call runs the block with ``restore=True``.
    """
    expansion = 1

    def __init__(
        self,
        in_planes,
        hidden,
        planes,
        stride=1,
        n_lipschitz_iters=None,
        sn_atol=1e-3,
        sn_rtol=1e-3,
    ):
        super(BasicImplicitBlock, self).__init__()
        self.initialized = False

        # Keyword arguments shared by every constrained convolution.
        conv_kwargs = dict(
            kernel_size=3, stride=1, padding=1, bias=False, coeff=args.coeff,
            n_iterations=n_lipschitz_iters, domain=2, codomain=2, atol=sn_atol, rtol=sn_rtol,
        )

        def make_branch():
            return nn.Sequential(
                base_layers.get_conv2d(in_planes, hidden, **conv_kwargs),
                ACTIVATION_FNS['relu'](),
                base_layers.get_conv2d(hidden, in_planes, **conv_kwargs),
                ACTIVATION_FNS['relu'](),
            )

        self.block = layers.imBlock(make_branch(), make_branch())

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            self.downsample = nn.Sequential()
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
                ACTIVATION_FNS['relu'](),
            )

    def forward(self, x):
        """Run the implicit block (with ``restore=True`` on the first call), then the projection."""
        if not self.initialized:
            out = self.block(x, restore=True)
            self.initialized = True
        else:
            out = self.block(x)
        return self.downsample(out)
189 |
190 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual unit.

    The output has ``expansion * planes`` channels. The shortcut is a 1x1
    conv + BN projection when the stride or channel count changes, and an
    empty ``nn.Sequential`` otherwise.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = F.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))
        return F.relu(main + self.shortcut(x))
217 |
218 |
class ResNet(nn.Module):
    """Stem conv + four stages of ``block`` + average pooling + linear classifier.

    ``block`` is instantiated as ``block(in_planes, hidden, planes, stride)``
    and must expose an ``expansion`` attribute.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (width, stride of the first block) for layer1 .. layer4.
        stage_specs = [(64, 1), (128, 2), (256, 2), (512, 2)]
        for idx, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, width, num_blocks[idx], stride=first_stride)
            setattr(self, 'layer{}'.format(idx + 1), stage)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, hidden, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, hidden, planes, block_stride))
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the classifier output for the image batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        flat = out.view(out.size(0), -1)
        return self.linear(flat)
250 |
251 |
class ImplicitResNet(nn.Module):
    """ResNet-style classifier whose stages are built from ``block``.

    The layout is a 3x3 stem conv with BN, four stages of widths
    64/128/256/512 (the last three start with stride 2), 4x4 average pooling
    and a linear head. ``block`` is called as
    ``block(in_planes, hidden, planes, stride)`` and must define ``expansion``.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ImplicitResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        n1, n2, n3, n4 = num_blocks[0], num_blocks[1], num_blocks[2], num_blocks[3]
        self.layer1 = self._make_layer(block, 64, 64, n1, stride=1)
        self.layer2 = self._make_layer(block, 128, 128, n2, stride=2)
        self.layer3 = self._make_layer(block, 256, 256, n3, stride=2)
        self.layer4 = self._make_layer(block, 512, 512, n4, stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, hidden, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; the first uses ``stride``, the others stride 1."""
        per_block_strides = [stride] + [1] * (num_blocks - 1)
        members = []
        for s in per_block_strides:
            members.append(block(self.in_planes, hidden, planes, s))
            # Track the channel count fed to the following block.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*members)

    def forward(self, x):
        """Return the classifier output for the image batch ``x``."""
        feats = F.relu(self.bn1(self.conv1(x)))
        feats = self.layer4(self.layer3(self.layer2(self.layer1(feats))))
        pooled = F.avg_pool2d(feats, 4)
        return self.linear(pooled.view(pooled.size(0), -1))
283 |
284 |
def ResNet18(num_classes=10):
    """Build a ``ResNet`` of ``BasicBlock`` units with one block in each of the four stages."""
    blocks_per_stage = [1, 1, 1, 1]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
287 |
def ImplicitResNet18(num_classes=10):
    """Build an ``ImplicitResNet`` of ``BasicImplicitBlock`` units with one block in each of the four stages."""
    blocks_per_stage = [1, 1, 1, 1]
    return ImplicitResNet(BasicImplicitBlock, blocks_per_stage, num_classes=num_classes)
290 |
def ResNet34(num_classes=10):
    """Build a ``ResNet`` of ``BasicBlock`` units with [3, 4, 6, 3] blocks per stage.

    ``num_classes`` is forwarded to ``ResNet`` (default 10, as before), so the
    signature matches ``ResNet18``.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)


# NOTE(review): `ResNet._make_layer` instantiates blocks as
# `block(in_planes, hidden, planes, stride)`, whereas `Bottleneck.__init__`
# accepts `(in_planes, planes, stride)`. The three Bottleneck factories below
# therefore look incompatible with `ResNet` as written -- confirm before use.
def ResNet50(num_classes=10):
    """Build a ``ResNet`` of ``Bottleneck`` units with [3, 4, 6, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)


def ResNet101(num_classes=10):
    """Build a ``ResNet`` of ``Bottleneck`` units with [3, 4, 23, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)


def ResNet152(num_classes=10):
    """Build a ``ResNet`` of ``Bottleneck`` units with [3, 8, 36, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)
305 |
306 |
# setup data loader
# Training images get a random 32x32 crop (after 4 pixels of padding) and a
# random horizontal flip; test images are only converted to tensors.
transform_train = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform_test = [
    transforms.ToTensor(),
]

# Each branch builds the train/test datasets under 'data' (downloading when
# needed), wraps them in DataLoaders and records the class count on `args`.
if args.data == 'cifar10':
    trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transforms.Compose(transform_train))
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
    testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transforms.Compose(transform_test))
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs)
    args.num_classes = 10
elif args.data == 'cifar100':
    trainset = torchvision.datasets.CIFAR100(root='data', train=True, download=True, transform=transforms.Compose(transform_train))
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
    testset = torchvision.datasets.CIFAR100(root='data', train=False, download=True, transform=transforms.Compose(transform_test))
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs)
    args.num_classes = 100
elif args.data == 'mnist':
    # MNIST images are padded by 2 pixels per side before the shared
    # transforms, and the resulting tensor is repeated 3x along dim 0
    # afterwards (the network's stem conv expects 3 input channels).
    transform_mnist = [transforms.Pad(2, 0),]
    transform_mnist2 = [lambda x: x.repeat((3, 1, 1))]
    trainset = torchvision.datasets.MNIST(root='data', train=True, download=True, transform=transforms.Compose(transform_mnist + transform_train + transform_mnist2))
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs)
    testset = torchvision.datasets.MNIST(root='data', train=False, download=True, transform=transforms.Compose(transform_mnist + transform_test + transform_mnist2))
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs)
    args.num_classes = 10

# Module-level running meters updated by `train` (per-batch time and loss).
batch_time = utils.RunningAverageMeter(0.97)
loss_meter = utils.RunningAverageMeter(0.97)
340 |
341 |
def update_lipschitz(model):
    """Call ``compute_weight(update=True)`` on every spectral-/induced-norm layer of ``model``.

    Runs under ``torch.no_grad()``. The two layer families are checked
    independently, exactly as two separate ``if`` statements would.
    """
    spectral_types = (base_layers.SpectralNormConv2d, base_layers.SpectralNormLinear)
    induced_types = (base_layers.InducedNormConv2d, base_layers.InducedNormLinear)
    with torch.no_grad():
        for layer in model.modules():
            for family in (spectral_types, induced_types):
                if isinstance(layer, family):
                    layer.compute_weight(update=True)
349 |
350 |
def train(args, model, device, train_loader, optimizer, epoch, ema):
    """Run one training epoch with a summed cross-entropy loss.

    After every optimizer step the Lipschitz-constrained weights are refreshed
    via ``update_lipschitz`` and ``ema.apply()`` is called. Progress is logged
    every ``args.log_interval`` batches using the module-level ``loss_meter``
    and ``batch_time`` meters.

    Args:
        args: parsed options; ``log_interval`` is read here.
        model: network being trained.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only for logging.
        ema: moving-average tracker whose ``apply()`` is called each step.
    """
    model.train()
    end = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        # `reduction='sum'` is the supported spelling of the deprecated
        # `size_average=False`; the loss is summed over the batch.
        loss = F.cross_entropy(output, target, reduction='sum')
        loss_meter.update(loss.item())

        loss.backward()
        optimizer.step()
        # Gradients are cleared right after the step, ready for the next batch.
        optimizer.zero_grad()
        update_lipschitz(model)
        ema.apply()

        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if batch_idx % args.log_interval == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)] | Time {batch_time.val:.3f} | Loss: {loss_meter.val:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_time=batch_time, loss_meter=loss_meter))
374 |
375 |
def eval_train(model, device, train_loader, ema):
    """Evaluate loss and accuracy on the training set between two ``ema.swap()`` calls.

    ``ema.swap()`` is called before and after the evaluation, and the
    Lipschitz-constrained weights are refreshed once after the first swap.
    NOTE(review): presumably ``swap()`` exchanges the live and averaged
    parameters, so the metrics are for the EMA weights -- confirm in lib/utils.py.
    The model is left in eval mode.

    Returns:
        ``(average_loss, accuracy)`` with the loss averaged over the dataset
        and the accuracy as a fraction in ``[0, 1]``.
    """
    ema.swap()
    update_lipschitz(model)
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # `reduction='sum'` replaces the deprecated `size_average=False`.
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    logger.info('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    training_accuracy = correct / len(train_loader.dataset)
    ema.swap()
    return train_loss, training_accuracy
396 |
397 |
def eval_test(model, device, test_loader, ema):
    """Evaluate loss and accuracy on the test set between two ``ema.swap()`` calls.

    Unlike ``eval_train`` this does not call ``update_lipschitz`` after the
    swap.
    NOTE(review): presumably ``swap()`` exchanges the live and averaged
    parameters, so the metrics are for the EMA weights -- confirm in lib/utils.py.
    The model is left in eval mode.

    Returns:
        ``(average_loss, accuracy)`` with the loss averaged over the dataset
        and the accuracy as a fraction in ``[0, 1]``.
    """
    ema.swap()
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # `reduction='sum'` replaces the deprecated `size_average=False`.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    logger.info('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    test_accuracy = correct / len(test_loader.dataset)
    ema.swap()
    return test_loss, test_accuracy
417 |
418 |
def adjust_learning_rate(optimizer, epoch):
    """Set every parameter group's learning rate from a step schedule.

    The base rate ``args.lr`` is scaled by 0.1 from epoch 75, by 0.01 from
    epoch 90 and by 0.001 from epoch 100.
    """
    # Latest milestone first, so the first match is the one that applies.
    milestones = ((100, 0.001), (90, 0.01), (75, 0.1))
    new_lr = args.lr
    for first_epoch, factor in milestones:
        if epoch >= first_epoch:
            new_lr = args.lr * factor
            break
    for group in optimizer.param_groups:
        group['lr'] = new_lr
430 |
431 |
def main():
    """Build the implicit ResNet-18, train it for ``args.epochs`` epochs and save the final weights.

    After each epoch the model is evaluated on the training and test sets.
    The model and optimizer state dicts are written to ``model_dir`` only
    after the last epoch.
    NOTE(review): ``args.save_freq`` is not used here, and the EMA state is
    not saved.
    """
    # model = torch.nn.DataParallel(ResNet18(num_classes=args.num_classes).to(device))
    model = torch.nn.DataParallel(ImplicitResNet18(num_classes=args.num_classes).to(device))
    # One gradient-free forward pass on a real batch before the EMA tracker and
    # optimizer are created; on this first call every BasicImplicitBlock runs
    # its block with restore=True.
    with torch.no_grad():
        x, _ = next(iter(train_loader))
        x = x.to(device)
        model(x)
    ema = utils.ExponentialMovingAverage(model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    logger.info(model)
    logger.info('EMA: {}'.format(ema))
    logger.info(optimizer)

    for epoch in range(1, args.epochs + 1):
        # step-wise learning-rate schedule (the active optimizer is Adam)
        adjust_learning_rate(optimizer, epoch)

        # one epoch of cross-entropy training (see `train`)
        train(args, model, device, train_loader, optimizer, epoch, ema)

        # evaluation on natural examples
        logger.info('================================================================')
        eval_train(model, device, train_loader, ema)
        eval_test(model, device, test_loader, ema)
        logger.info('================================================================')

        # save checkpoint (only after the final epoch)
        if epoch == args.epochs:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, 'model-wideres-epoch{}.pt'.format(epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, 'opt-wideres-checkpoint_epoch{}.tar'.format(epoch)))
465 |
466 |
467 | if __name__ == '__main__':
468 | main()
--------------------------------------------------------------------------------
/train_toy.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 |
5 | import argparse
6 | import os
7 | import time
8 | import math
9 | import numpy as np
10 |
11 | import torch
12 |
13 | import lib.optimizers as optim
14 | import lib.layers.base as base_layers
15 | import lib.layers as layers
16 | import lib.toy_data as toy_data
17 | import lib.utils as utils
18 | from lib.visualize_flow import visualize_transform
19 |
20 |
# Maps an activation name (the value of --act) to the class that constructs it.
ACTIVATION_FNS = {
    'identity': base_layers.Identity,
    'relu': torch.nn.ReLU,
    'tanh': torch.nn.Tanh,
    'elu': torch.nn.ELU,
    'selu': torch.nn.SELU,
    'fullsort': base_layers.FullSort,
    'maxmin': base_layers.MaxMin,
    'swish': base_layers.Swish,
    'lcube': base_layers.LipschitzCube,
    'sin': base_layers.Sin,
}

# Command-line interface.
# NOTE: boolean options use `type=eval`, i.e. the command-line string is
# evaluated as a Python expression; pass only the literals True / False.
parser = argparse.ArgumentParser()
# Dataset, architecture and Lipschitz-constraint options.
parser.add_argument(
    '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'],
    type=str, default='pinwheel'
)
parser.add_argument('--arch', choices=['iresnet', 'realnvp', 'implicit'], default='implicit')
parser.add_argument('--coeff', type=float, default=0.9)
# One character per norm order; 'f' stands for infinity (see parse_vnorms).
parser.add_argument('--vnorms', type=str, default='222222')
parser.add_argument('--n-lipschitz-iters', type=int, default=5)
parser.add_argument('--atol', type=float, default=None)
parser.add_argument('--rtol', type=float, default=None)
parser.add_argument('--learn-p', type=eval, choices=[True, False], default=False)
parser.add_argument('--mixed', type=eval, choices=[True, False], default=True)

# Network shape and log-determinant estimator options.
# --dims holds the hidden widths separated by '-'.
parser.add_argument('--dims', type=str, default='128-128-128-128')
parser.add_argument('--act', type=str, choices=ACTIVATION_FNS.keys(), default='sin')
parser.add_argument('--nblocks', type=int, default=100)
parser.add_argument('--brute-force', type=eval, choices=[True, False], default=False)
parser.add_argument('--actnorm', type=eval, choices=[True, False], default=False)
parser.add_argument('--batchnorm', type=eval, choices=[True, False], default=False)
parser.add_argument('--exact-trace', type=eval, choices=[True, False], default=False)
parser.add_argument('--n-power-series', type=int, default=None)
parser.add_argument('--n-samples', type=int, default=1)
parser.add_argument('--n-dist', choices=['geometric', 'poisson'], default='geometric')

# Optimisation options.
parser.add_argument('--niters', type=int, default=50000)
parser.add_argument('--batch_size', type=int, default=1000)
parser.add_argument('--test_batch_size', type=int, default=10000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight-decay', type=float, default=1e-5)
parser.add_argument('--annealing-iters', type=int, default=0)

# Checkpoint to resume from.
# NOTE(review): --begin-epoch is not read anywhere in the visible code.
parser.add_argument('--resume', type=str, default=None)
parser.add_argument('--begin-epoch', type=int, default=0)

# Output directory and reporting frequencies (in iterations).
parser.add_argument('--save', type=str, default='experiments/iresnet_toy')
parser.add_argument('--viz_freq', type=int, default=1000)
parser.add_argument('--val_freq', type=int, default=1000)
parser.add_argument('--log_freq', type=int, default=100)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()

# logger
utils.makedirs(args.save)
logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
logger.info(args)

# Use the GPU selected by --gpu when CUDA is available, the CPU otherwise.
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

# Seed NumPy and PyTorch from --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if device.type == 'cuda':
    torch.cuda.manual_seed(args.seed)
88 |
89 |
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
92 |
93 |
def standard_normal_sample(size):
    """Return a tensor of shape ``size`` drawn with ``torch.randn``."""
    noise = torch.randn(size)
    return noise
96 |
97 |
def standard_normal_logprob(z):
    """Return the element-wise log-density of a standard normal at ``z``.

    Computes ``-0.5 * log(2 * pi) - z ** 2 / 2`` for every entry of ``z``.
    """
    half_square = z.pow(2) / 2
    log_norm_const = -0.5 * np.log(2 * np.pi)
    return log_norm_const - half_square
101 |
102 |
def compute_loss(args, model, batch_size=None, beta=1.):
    """Draw a fresh toy-data batch and return the negative log-likelihood terms.

    Args:
        args: parsed options; ``data`` and ``batch_size`` are read.
        model: flow called as ``model(x, zero)`` returning ``(z, delta_logp)``.
        batch_size: number of points to draw; defaults to ``args.batch_size``.
        beta: weight applied to ``delta_logp`` in the loss.

    Returns:
        ``(loss, mean log p(z), mean of -delta_logp)`` where
        ``loss = -mean(log p(z) - beta * delta_logp)``.
    """
    n_points = args.batch_size if batch_size is None else batch_size

    # Sample data and move it to the training device as float32.
    batch = toy_data.inf_train_gen(args.data, batch_size=n_points)
    x = torch.from_numpy(batch).type(torch.float32).to(device)
    logp_init = torch.zeros(x.shape[0], 1).to(x)

    # Map data to latent space together with the accumulated log-density change.
    z, delta_logp = model(x, logp_init)

    # Log-density of the latent under a standard normal, summed over dim 1.
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - beta * delta_logp
    return -torch.mean(logpx), torch.mean(logpz), torch.mean(-delta_logp)
120 |
121 |
def parse_vnorms():
    """Turn ``args.vnorms`` into per-layer (domain, codomain) norm orders.

    Each character is one order; ``'f'`` means infinity, anything else is
    converted with ``float``. Returns the list without its last entry and the
    list without its first entry.
    """
    orders = [float('inf') if ch == 'f' else float(ch) for ch in args.vnorms]
    return orders[:-1], orders[1:]
130 |
131 |
def compute_p_grads(model):
    """Back-propagate an auxiliary term into the learnable norm orders.

    Sums ``compute_one_iter()`` over every induced-norm layer of ``model``,
    scales the sum by ``0.01 / n_layers`` and calls ``backward()`` on it.
    Afterwards any NaN gradient on a layer's ``domain`` is discarded.

    Does nothing when ``model`` has no induced-norm layers. Previously this
    case raised ``ZeroDivisionError``, e.g. when ``--learn-p True`` was
    combined with an architecture that has none.
    """
    induced_types = (base_layers.InducedNormConv2d, base_layers.InducedNormLinear)
    induced_layers = [m for m in model.modules() if isinstance(m, induced_types)]
    if not induced_layers:
        return

    scales = 0.
    for m in induced_layers:
        scales = scales + m.compute_one_iter()
    scales.mul(1 / len(induced_layers)).mul(0.01).backward()

    for m in induced_layers:
        # Drop NaN gradients instead of letting them reach the optimizer.
        if m.domain.grad is not None and torch.isnan(m.domain.grad):
            m.domain.grad = None
144 |
145 |
def build_nnet(dims, activation_fn=torch.nn.ReLU):
    """Build a sequence of Lipschitz-constrained linear layers.

    Consecutive entries of ``dims`` give each layer's input/output width, and
    an ``activation_fn()`` instance precedes every layer except the first.
    Norm orders come from ``parse_vnorms()``. With ``args.learn_p`` they are
    replaced by ``torch.nn.Parameter`` objects: one per layer when
    ``args.mixed`` is set, otherwise a single parameter shared by all layers.

    Args:
        dims: sequence of layer widths.
        activation_fn: zero-argument callable returning an activation module.

    Returns:
        ``torch.nn.Sequential`` holding the layers.
    """
    domains, codomains = parse_vnorms()
    if args.learn_p:
        if args.mixed:
            domains = [torch.nn.Parameter(torch.tensor(0.)) for _ in domains]
        else:
            # The same Parameter object is repeated, so it is shared.
            domains = [torch.nn.Parameter(torch.tensor(0.))] * len(domains)
        # Each layer's codomain is the next layer's domain, wrapping around.
        codomains = domains[1:] + [domains[0]]

    modules = []
    layer_specs = zip(dims[:-1], dims[1:], domains, codomains)
    for idx, (n_in, n_out, dom, codom) in enumerate(layer_specs):
        if idx > 0:
            modules.append(activation_fn())
        linear = base_layers.get_linear(
            n_in,
            n_out,
            coeff=args.coeff,
            n_iterations=args.n_lipschitz_iters,
            atol=args.atol,
            rtol=args.rtol,
            domain=dom,
            codomain=codom,
            zero_init=(n_out == 2),
        )
        modules.append(linear)
    return torch.nn.Sequential(*modules)
172 |
173 |
def update_lipschitz(model, n_iterations):
    """Call ``compute_weight(update=True, n_iterations=...)`` on every spectral-/induced-norm layer.

    The two layer families are checked independently, exactly as two separate
    ``if`` statements would.
    """
    spectral_types = (base_layers.SpectralNormConv2d, base_layers.SpectralNormLinear)
    induced_types = (base_layers.InducedNormConv2d, base_layers.InducedNormLinear)
    for layer in model.modules():
        for family in (spectral_types, induced_types):
            if isinstance(layer, family):
                layer.compute_weight(update=True, n_iterations=n_iterations)
180 |
181 |
def get_ords(model):
    """Collect the (domain, codomain) norm orders of every induced-norm layer.

    Returns a flat list ``[domain_0, codomain_0, domain_1, codomain_1, ...]``
    in ``model.modules()`` order, with tensor values converted via ``.item()``.
    """
    def _as_number(value):
        return value.item() if torch.is_tensor(value) else value

    induced_types = (base_layers.InducedNormConv2d, base_layers.InducedNormLinear)
    collected = []
    for layer in model.modules():
        if not isinstance(layer, induced_types):
            continue
        dom, codom = layer.compute_domain_codomain()
        collected.extend((_as_number(dom), _as_number(codom)))
    return collected
194 |
195 |
def pretty_repr(a):
    """Format the numbers in ``a`` with two decimals, comma-separated inside ``[[...]]``."""
    formatted = ['{:.2f}'.format(value) for value in a]
    return '[[' + ','.join(formatted) + ']]'
198 |
199 |
if __name__ == '__main__':
    # Script entry point: build the flow selected by --arch, optionally
    # restore a checkpoint, then train for --niters iterations with periodic
    # logging, validation/checkpointing and visualisation.

    activation_fn = ACTIVATION_FNS[args.act]

    if args.arch == 'iresnet':
        # 2-D input and output with the hidden widths from --dims in between.
        dims = [2] + list(map(int, args.dims.split('-'))) + [2]
        blocks = []
        if args.actnorm: blocks.append(layers.ActNorm1d(2))
        for _ in range(args.nblocks):
            blocks.append(
                layers.iResBlock(
                    build_nnet(dims, activation_fn),
                    n_dist=args.n_dist,
                    n_power_series=args.n_power_series,
                    exact_trace=args.exact_trace,
                    brute_force=args.brute_force,
                    n_samples=args.n_samples,
                    neumann_grad=False,
                    grad_in_forward=False,
                )
            )
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
    elif args.arch == 'implicit':
        dims = [2] + list(map(int, args.dims.split('-'))) + [2]
        blocks = []
        if args.actnorm: blocks.append(layers.ActNorm1d(2))
        for _ in range(args.nblocks):
            blocks.append(
                layers.imBlock(
                    build_nnet(dims, activation_fn),
                    build_nnet(dims, activation_fn),
                    n_dist=args.n_dist,
                    n_power_series=args.n_power_series,
                    exact_trace=args.exact_trace,
                    brute_force=args.brute_force,
                    n_samples=args.n_samples,
                    neumann_grad=False,
                    grad_in_forward=False,  # toy data needn't save memory
                )
            )
        # Only this architecture is wrapped in DataParallel.
        model = torch.nn.DataParallel(layers.SequentialFlow(blocks).to(device))
    elif args.arch == 'realnvp':
        blocks = []
        for _ in range(args.nblocks):
            blocks.append(layers.CouplingBlock(2, swap=False))
            blocks.append(layers.CouplingBlock(2, swap=True))
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    logpz_meter = utils.RunningAverageMeter(0.93)
    delta_logp_meter = utils.RunningAverageMeter(0.93)

    end = time.time()

    if (args.resume is not None):
        logger.info('Resuming model from {}'.format(args.resume))
        # One gradient-free forward pass with restore=True before the
        # checkpoint is loaded.
        with torch.no_grad():
            x = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
            x = torch.from_numpy(x).type(torch.float32).to(device)
            model(x, restore=True)
        checkpt = torch.load(args.resume)
        # Entries whose key contains 'last_n_samples' are not restored.
        sd = {k: v for k, v in checkpt['state_dict'].items() if 'last_n_samples' not in k}
        state = model.state_dict()
        state.update(sd)
        model.load_state_dict(state, strict=True)
        del checkpt
        del state
    else:
        with torch.no_grad():
            x = toy_data.inf_train_gen(args.data, batch_size=args.batch_size)
            x = torch.from_numpy(x).type(torch.float32).to(device)
            model(x, restore=True)

    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        # Linearly anneal the weight of delta_logp from 0 to 1 over the first
        # --annealing-iters iterations (no annealing when that is 0).
        beta = min(1, itr / args.annealing_iters) if args.annealing_iters > 0 else 1.
        loss, logpz, delta_logp = compute_loss(args, model, beta=beta)
        loss_meter.update(loss.item())
        logpz_meter.update(logpz.item())
        delta_logp_meter.update(delta_logp.item())
        loss.backward()
        if args.learn_p and itr > args.annealing_iters: compute_p_grads(model)
        optimizer.step()
        # Refresh the Lipschitz-constrained weights after every parameter update.
        update_lipschitz(model, args.n_lipschitz_iters)

        time_meter.update(time.time() - end)

        if itr % args.log_freq == 0:
            logger.info(
                'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f})'
                ' | Logp(z) {:.6f}({:.6f}) | DeltaLogp {:.6f}({:.6f})'.format(
                    itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, logpz_meter.val, logpz_meter.avg,
                    delta_logp_meter.val, delta_logp_meter.avg
                )
            )

        if itr % args.val_freq == 0 or itr == args.niters:
            # Use many more iterations for the Lipschitz update before evaluating.
            update_lipschitz(model, 200)
            with torch.no_grad():
                model.eval()
                test_loss, test_logpz, test_delta_logp = compute_loss(args, model, batch_size=args.test_batch_size)
                log_message = (
                    '[TEST] Iter {:04d} | Test Loss {:.6f} '
                    '| Test Logp(z) {:.6f} | Test DeltaLogp {:.6f}'.format(
                        itr, test_loss.item(), test_logpz.item(), test_delta_logp.item()
                    )
                )
                logger.info(log_message)

                logger.info('Ords: {}'.format(pretty_repr(get_ords(model))))

                # Keep only the checkpoint with the best test loss so far.
                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr == 1 or itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                p_samples = toy_data.inf_train_gen(args.data, batch_size=20000)

                # Only the 'implicit' model is wrapped in DataParallel; the
                # other architectures have no `.module` attribute, so unwrap
                # conditionally instead of always reading `model.module`.
                flow = model.module if isinstance(model, torch.nn.DataParallel) else model
                sample_fn, density_fn = flow.inverse, model.forward

                plt.figure(figsize=(9, 3))
                visualize_transform(
                    p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn,
                    samples=True, npts=400, device=device
                )
                fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr))
                utils.makedirs(os.path.dirname(fig_filename))
                plt.savefig(fig_filename)
                plt.close()
                model.train()

        end = time.time()

    logger.info('Training has finished.')
354 |
--------------------------------------------------------------------------------