├── Readme.md
├── assets
│   └── moon_flow.png
├── nflib
│   ├── __init__.py
│   ├── flows.py
│   ├── made.py
│   ├── nets.py
│   └── spline_flows.py
└── nflib1.ipynb

/Readme.md:
--------------------------------------------------------------------------------
1 | # pytorch-normalizing-flows
2 | 
3 | Implementations of normalizing flows (NICE, RealNVP, MAF, IAF, Neural Spline Flows, etc.) in PyTorch.
4 | 
5 | ![Normalizing Flow fitting a 2D dataset](https://github.com/karpathy/pytorch-normalizing-flows/blob/master/assets/moon_flow.png)
6 | 
7 | **todos**
8 | - TODO: make work on GPU
9 | - TODO: 2D -> ND: get it working on (flattened) MNIST
10 | - TODO: ND -> images (multi-scale architectures, Glow nets, etc) on MNIST/CIFAR/ImageNet
11 | - TODO: more stable residual-like IAF-style updates (tried but didn't work too well)
12 | - TODO: parallel wavenet
13 | - TODO: radial/planar 2D flows from Rezende & Mohamed 2015?
14 | 
--------------------------------------------------------------------------------
/assets/moon_flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karpathy/pytorch-normalizing-flows/b60e119b37be10ce2930ef9fa17e58686aaf2b3d/assets/moon_flow.png
--------------------------------------------------------------------------------
/nflib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karpathy/pytorch-normalizing-flows/b60e119b37be10ce2930ef9fa17e58686aaf2b3d/nflib/__init__.py
--------------------------------------------------------------------------------
/nflib/flows.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements various flows.
3 | Each flow is invertible, so it can be forward()ed and backward()ed.
4 | Notice that backward() is not backward as in backprop but simply inversion.
5 | Each flow also outputs its log|det J| term, needed for the change-of-variables likelihood.
6 | 
7 | Reference:
8 | 
9 | NICE: Non-linear Independent Components Estimation, Dinh et al. 2014
10 | https://arxiv.org/abs/1410.8516
11 | 
12 | Variational Inference with Normalizing Flows, Rezende and Mohamed 2015
13 | https://arxiv.org/abs/1505.05770
14 | 
15 | Density estimation using Real NVP, Dinh et al. May 2016
16 | https://arxiv.org/abs/1605.08803
17 | (Laurent Dinh's extension of NICE)
18 | 
19 | Improved Variational Inference with Inverse Autoregressive Flow, Kingma et al., June 2016
20 | https://arxiv.org/abs/1606.04934
21 | (IAF)
22 | 
23 | Masked Autoregressive Flow for Density Estimation, Papamakarios et al., May 2017
24 | https://arxiv.org/abs/1705.07057
25 | "The advantage of Real NVP compared to MAF and IAF is that it can both generate data and estimate densities with one forward pass only, whereas MAF would need D passes to generate data and IAF would need D passes to estimate densities."
26 | (MAF)
27 | 
28 | Glow: Generative Flow with Invertible 1x1 Convolutions, Kingma and Dhariwal, Jul 2018
29 | https://arxiv.org/abs/1807.03039
30 | 
31 | "Normalizing Flows for Probabilistic Modeling and Inference"
32 | https://arxiv.org/abs/1912.02762
33 | (review paper)
34 | """
35 | 
36 | import numpy as np
37 | import torch
38 | import torch.nn.functional as F
39 | from torch import nn
40 | 
41 | from nflib.nets import LeafParam, MLP, ARMLP
42 | 
43 | class AffineConstantFlow(nn.Module):
44 |     """
45 |     Scales + Shifts the flow by (learned) constants per dimension.
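    Concretely, forward() computes z = x * exp(s) + t with log|det J| = sum(s), and backward() inverts it.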
46 |     In the NICE paper there is a Scaling layer, which is the special case of this where t is None.
47 |     """
48 |     def __init__(self, dim, scale=True, shift=True):
49 |         super().__init__()
50 |         self.s = nn.Parameter(torch.randn(1, dim, requires_grad=True)) if scale else None
51 |         self.t = nn.Parameter(torch.randn(1, dim, requires_grad=True)) if shift else None
52 | 
53 |     def forward(self, x):
54 |         s = self.s if self.s is not None else x.new_zeros(x.size())
55 |         t = self.t if self.t is not None else x.new_zeros(x.size())
56 |         z = x * torch.exp(s) + t
57 |         log_det = torch.sum(s, dim=1)
58 |         return z, log_det
59 | 
60 |     def backward(self, z):
61 |         s = self.s if self.s is not None else z.new_zeros(z.size())
62 |         t = self.t if self.t is not None else z.new_zeros(z.size())
63 |         x = (z - t) * torch.exp(-s)
64 |         log_det = torch.sum(-s, dim=1)
65 |         return x, log_det
66 | 
67 | 
68 | class ActNorm(AffineConstantFlow):
69 |     """
70 |     Really an AffineConstantFlow but with a data-dependent initialization,
71 |     where on the very first batch we initialize s,t so that the output
72 |     is unit Gaussian, as described in the Glow paper.
73 |     """
74 |     def __init__(self, *args, **kwargs):
75 |         super().__init__(*args, **kwargs)
76 |         self.data_dep_init_done = False
77 | 
78 |     def forward(self, x):
79 |         # first batch is used for init
80 |         if not self.data_dep_init_done:
81 |             assert self.s is not None and self.t is not None # for now
82 |             self.s.data = (-torch.log(x.std(dim=0, keepdim=True))).detach()
83 |             self.t.data = (-(x * torch.exp(self.s)).mean(dim=0, keepdim=True)).detach()
84 |             self.data_dep_init_done = True
85 |         return super().forward(x)
86 | 
87 | 
88 | class AffineHalfFlow(nn.Module):
89 |     """
90 |     As seen in RealNVP: an affine coupling flow (z = x * exp(s) + t), where half of the
91 |     dimensions in x are linearly scaled and shifted as a function of the other half.
92 |     Which half is which is determined by the parity bit.
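    The transformed half follows z1 = exp(s(x0)) * x1 + t(x0), giving log|det J| = sum(s(x0)).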
93 | - RealNVP both scales and shifts (default) 94 | - NICE only shifts 95 | """ 96 | def __init__(self, dim, parity, net_class=MLP, nh=24, scale=True, shift=True): 97 | super().__init__() 98 | self.dim = dim 99 | self.parity = parity 100 | self.s_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2) 101 | self.t_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2) 102 | if scale: 103 | self.s_cond = net_class(self.dim // 2, self.dim // 2, nh) 104 | if shift: 105 | self.t_cond = net_class(self.dim // 2, self.dim // 2, nh) 106 | 107 | def forward(self, x): 108 | x0, x1 = x[:,::2], x[:,1::2] 109 | if self.parity: 110 | x0, x1 = x1, x0 111 | s = self.s_cond(x0) 112 | t = self.t_cond(x0) 113 | z0 = x0 # untouched half 114 | z1 = torch.exp(s) * x1 + t # transform this half as a function of the other 115 | if self.parity: 116 | z0, z1 = z1, z0 117 | z = torch.cat([z0, z1], dim=1) 118 | log_det = torch.sum(s, dim=1) 119 | return z, log_det 120 | 121 | def backward(self, z): 122 | z0, z1 = z[:,::2], z[:,1::2] 123 | if self.parity: 124 | z0, z1 = z1, z0 125 | s = self.s_cond(z0) 126 | t = self.t_cond(z0) 127 | x0 = z0 # this was the same 128 | x1 = (z1 - t) * torch.exp(-s) # reverse the transform on this half 129 | if self.parity: 130 | x0, x1 = x1, x0 131 | x = torch.cat([x0, x1], dim=1) 132 | log_det = torch.sum(-s, dim=1) 133 | return x, log_det 134 | 135 | 136 | class SlowMAF(nn.Module): 137 | """ 138 | Masked Autoregressive Flow, slow version with explicit networks per dim 139 | """ 140 | def __init__(self, dim, parity, net_class=MLP, nh=24): 141 | super().__init__() 142 | self.dim = dim 143 | self.layers = nn.ModuleDict() 144 | self.layers[str(0)] = LeafParam(2) 145 | for i in range(1, dim): 146 | self.layers[str(i)] = net_class(i, 2, nh) 147 | self.order = list(range(dim)) if parity else list(range(dim))[::-1] 148 | 149 | def forward(self, x): 150 | z = torch.zeros_like(x) 151 | log_det = torch.zeros(x.size(0)) 152 | for i in range(self.dim): 153 | st = self.layers[str(i)](x[:, :i]) 154 | s, t = st[:, 0], st[:, 1] 155 | z[:, self.order[i]] = x[:, i] * torch.exp(s) + t 156 | log_det += s 157 | return z, log_det 158 | 159 | def backward(self, z): 160 | x = torch.zeros_like(z) 161 | log_det = torch.zeros(z.size(0)) 162 | for i in range(self.dim): 163 | st = self.layers[str(i)](x[:, :i]) 164 | s, t = st[:, 0], st[:, 1] 165 | x[:, i] = (z[:, self.order[i]] - t) * torch.exp(-s) 166 | log_det += -s 167 | return x, log_det 168 | 169 | class MAF(nn.Module): 170 | """ Masked Autoregressive Flow that uses a MADE-style network for fast forward """ 171 | 172 | def __init__(self, dim, parity, net_class=ARMLP, nh=24): 173 | super().__init__() 174 | self.dim = dim 175 | self.net = net_class(dim, dim*2, nh) 176 | self.parity = parity 177 | 178 | def forward(self, x): 179 | # here we see that we are evaluating all of z in parallel, so density estimation will be fast 180 | st = self.net(x) 181 | s, t = st.split(self.dim, dim=1) 182 | z = x * torch.exp(s) + t 183 | # reverse order, so if we stack MAFs correct things happen 184 | z = z.flip(dims=(1,)) if self.parity else z 185 | log_det = torch.sum(s, dim=1) 186 | return z, log_det 187 | 188 | def backward(self, z): 189 | # we have to decode the x one at a time, sequentially 190 | x = torch.zeros_like(z) 191 | log_det = torch.zeros(z.size(0)) 192 | z = z.flip(dims=(1,)) if self.parity else z 193 | for i in range(self.dim): 194 | st = self.net(x.clone()) # clone to avoid in-place op errors if using IAF 195 | s, t = st.split(self.dim, dim=1) 196 | x[:, i] = 
(z[:, i] - t[:, i]) * torch.exp(-s[:, i]) 197 | log_det += -s[:, i] 198 | return x, log_det 199 | 200 | class IAF(MAF): 201 | def __init__(self, *args, **kwargs): 202 | super().__init__(*args, **kwargs) 203 | """ 204 | reverse the flow, giving an Inverse Autoregressive Flow (IAF) instead, 205 | where sampling will be fast but density estimation slow 206 | """ 207 | self.forward, self.backward = self.backward, self.forward 208 | 209 | 210 | class Invertible1x1Conv(nn.Module): 211 | """ 212 | As introduced in Glow paper. 213 | """ 214 | 215 | def __init__(self, dim): 216 | super().__init__() 217 | self.dim = dim 218 | Q = torch.nn.init.orthogonal_(torch.randn(dim, dim)) 219 | P, L, U = torch.lu_unpack(*Q.lu()) 220 | self.P = P # remains fixed during optimization 221 | self.L = nn.Parameter(L) # lower triangular portion 222 | self.S = nn.Parameter(U.diag()) # "crop out" the diagonal to its own parameter 223 | self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S 224 | 225 | def _assemble_W(self): 226 | """ assemble W from its pieces (P, L, U, S) """ 227 | L = torch.tril(self.L, diagonal=-1) + torch.diag(torch.ones(self.dim)) 228 | U = torch.triu(self.U, diagonal=1) 229 | W = self.P @ L @ (U + torch.diag(self.S)) 230 | return W 231 | 232 | def forward(self, x): 233 | W = self._assemble_W() 234 | z = x @ W 235 | log_det = torch.sum(torch.log(torch.abs(self.S))) 236 | return z, log_det 237 | 238 | def backward(self, z): 239 | W = self._assemble_W() 240 | W_inv = torch.inverse(W) 241 | x = z @ W_inv 242 | log_det = -torch.sum(torch.log(torch.abs(self.S))) 243 | return x, log_det 244 | 245 | # ------------------------------------------------------------------------ 246 | 247 | class NormalizingFlow(nn.Module): 248 | """ A sequence of Normalizing Flows is a Normalizing Flow """ 249 | 250 | def __init__(self, flows): 251 | super().__init__() 252 | self.flows = nn.ModuleList(flows) 253 | 254 | def forward(self, x): 255 | m, _ = x.shape 256 | log_det = torch.zeros(m) 257 | zs = [x] 258 | for flow in self.flows: 259 | x, ld = flow.forward(x) 260 | log_det += ld 261 | zs.append(x) 262 | return zs, log_det 263 | 264 | def backward(self, z): 265 | m, _ = z.shape 266 | log_det = torch.zeros(m) 267 | xs = [z] 268 | for flow in self.flows[::-1]: 269 | z, ld = flow.backward(z) 270 | log_det += ld 271 | xs.append(z) 272 | return xs, log_det 273 | 274 | class NormalizingFlowModel(nn.Module): 275 | """ A Normalizing Flow Model is a (prior, flow) pair """ 276 | 277 | def __init__(self, prior, flows): 278 | super().__init__() 279 | self.prior = prior 280 | self.flow = NormalizingFlow(flows) 281 | 282 | def forward(self, x): 283 | zs, log_det = self.flow.forward(x) 284 | prior_logprob = self.prior.log_prob(zs[-1]).view(x.size(0), -1).sum(1) 285 | return zs, prior_logprob, log_det 286 | 287 | def backward(self, z): 288 | xs, log_det = self.flow.backward(z) 289 | return xs, log_det 290 | 291 | def sample(self, num_samples): 292 | z = self.prior.sample((num_samples,)) 293 | xs, _ = self.flow.backward(z) 294 | return xs 295 | -------------------------------------------------------------------------------- /nflib/made.py: -------------------------------------------------------------------------------- 1 | """ 2 | # copy pasted from my earlier MADE implementation 3 | # https://github.com/karpathy/pytorch-made 4 | 5 | Implements a Masked Autoregressive MLP, where carefully constructed 6 | binary masks over weights ensure the autoregressive property. 
7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | class MaskedLinear(nn.Linear): 15 | """ same as Linear except has a configurable mask on the weights """ 16 | 17 | def __init__(self, in_features, out_features, bias=True): 18 | super().__init__(in_features, out_features, bias) 19 | self.register_buffer('mask', torch.ones(out_features, in_features)) 20 | 21 | def set_mask(self, mask): 22 | self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T)) 23 | 24 | def forward(self, input): 25 | return F.linear(input, self.mask * self.weight, self.bias) 26 | 27 | class MADE(nn.Module): 28 | def __init__(self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=False): 29 | """ 30 | nin: integer; number of inputs 31 | hidden sizes: a list of integers; number of units in hidden layers 32 | nout: integer; number of outputs, which usually collectively parameterize some kind of 1D distribution 33 | note: if nout is e.g. 2x larger than nin (perhaps the mean and std), then the first nin 34 | will be all the means and the second nin will be stds. i.e. output dimensions depend on the 35 | same input dimensions in "chunks" and should be carefully decoded downstream appropriately. 36 | the output of running the tests for this file makes this a bit more clear with examples. 37 | num_masks: can be used to train ensemble over orderings/connections 38 | natural_ordering: force natural ordering of dimensions, don't use random permutations 39 | """ 40 | 41 | super().__init__() 42 | self.nin = nin 43 | self.nout = nout 44 | self.hidden_sizes = hidden_sizes 45 | assert self.nout % self.nin == 0, "nout must be integer multiple of nin" 46 | 47 | # define a simple MLP neural net 48 | self.net = [] 49 | hs = [nin] + hidden_sizes + [nout] 50 | for h0,h1 in zip(hs, hs[1:]): 51 | self.net.extend([ 52 | MaskedLinear(h0, h1), 53 | nn.ReLU(), 54 | ]) 55 | self.net.pop() # pop the last ReLU for the output layer 56 | self.net = nn.Sequential(*self.net) 57 | 58 | # seeds for orders/connectivities of the model ensemble 59 | self.natural_ordering = natural_ordering 60 | self.num_masks = num_masks 61 | self.seed = 0 # for cycling through num_masks orderings 62 | 63 | self.m = {} 64 | self.update_masks() # builds the initial self.m connectivity 65 | # note, we could also precompute the masks and cache them, but this 66 | # could get memory expensive for large number of masks. 
67 | 68 | def update_masks(self): 69 | if self.m and self.num_masks == 1: return # only a single seed, skip for efficiency 70 | L = len(self.hidden_sizes) 71 | 72 | # fetch the next seed and construct a random stream 73 | rng = np.random.RandomState(self.seed) 74 | self.seed = (self.seed + 1) % self.num_masks 75 | 76 | # sample the order of the inputs and the connectivity of all neurons 77 | self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin) 78 | for l in range(L): 79 | self.m[l] = rng.randint(self.m[l-1].min(), self.nin-1, size=self.hidden_sizes[l]) 80 | 81 | # construct the mask matrices 82 | masks = [self.m[l-1][:,None] <= self.m[l][None,:] for l in range(L)] 83 | masks.append(self.m[L-1][:,None] < self.m[-1][None,:]) 84 | 85 | # handle the case where nout = nin * k, for integer k > 1 86 | if self.nout > self.nin: 87 | k = int(self.nout / self.nin) 88 | # replicate the mask across the other outputs 89 | masks[-1] = np.concatenate([masks[-1]]*k, axis=1) 90 | 91 | # set the masks in all MaskedLinear layers 92 | layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)] 93 | for l,m in zip(layers, masks): 94 | l.set_mask(m) 95 | 96 | def forward(self, x): 97 | return self.net(x) 98 | -------------------------------------------------------------------------------- /nflib/nets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various helper network modules 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from nflib.made import MADE 10 | 11 | class LeafParam(nn.Module): 12 | """ 13 | just ignores the input and outputs a parameter tensor, lol 14 | todo maybe this exists in PyTorch somewhere? 15 | """ 16 | def __init__(self, n): 17 | super().__init__() 18 | self.p = nn.Parameter(torch.zeros(1,n)) 19 | 20 | def forward(self, x): 21 | return self.p.expand(x.size(0), self.p.size(1)) 22 | 23 | class PositionalEncoder(nn.Module): 24 | """ 25 | Each dimension of the input gets expanded out with sins/coses 26 | to "carve" out the space. Useful in low-dimensional cases with 27 | tightly "curled up" data. 28 | """ 29 | def __init__(self, freqs=(.5,1,2,4,8)): 30 | super().__init__() 31 | self.freqs = freqs 32 | 33 | def forward(self, x): 34 | sines = [torch.sin(x * f) for f in self.freqs] 35 | coses = [torch.cos(x * f) for f in self.freqs] 36 | out = torch.cat(sines + coses, dim=1) 37 | return out 38 | 39 | class MLP(nn.Module): 40 | """ a simple 4-layer MLP """ 41 | 42 | def __init__(self, nin, nout, nh): 43 | super().__init__() 44 | self.net = nn.Sequential( 45 | nn.Linear(nin, nh), 46 | nn.LeakyReLU(0.2), 47 | nn.Linear(nh, nh), 48 | nn.LeakyReLU(0.2), 49 | nn.Linear(nh, nh), 50 | nn.LeakyReLU(0.2), 51 | nn.Linear(nh, nout), 52 | ) 53 | def forward(self, x): 54 | return self.net(x) 55 | 56 | class PosEncMLP(nn.Module): 57 | """ 58 | Position Encoded MLP, where the first layer performs position encoding. 59 | Each dimension of the input gets transformed to len(freqs)*2 dimensions 60 | using a fixed transformation of sin/cos of given frequencies. 
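    With the default freqs, each input dimension maps to 5 * 2 = 10 features, e.g. a 2D input becomes a 20D vector before the MLP.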
61 | """ 62 | def __init__(self, nin, nout, nh, freqs=(.5,1,2,4,8)): 63 | super().__init__() 64 | self.net = nn.Sequential( 65 | PositionalEncoder(freqs), 66 | MLP(nin * len(freqs) * 2, nout, nh), 67 | ) 68 | def forward(self, x): 69 | return self.net(x) 70 | 71 | class ARMLP(nn.Module): 72 | """ a 4-layer auto-regressive MLP, wrapper around MADE net """ 73 | 74 | def __init__(self, nin, nout, nh): 75 | super().__init__() 76 | self.net = MADE(nin, [nh, nh, nh], nout, num_masks=1, natural_ordering=True) 77 | 78 | def forward(self, x): 79 | return self.net(x) 80 | -------------------------------------------------------------------------------- /nflib/spline_flows.py: -------------------------------------------------------------------------------- 1 | """ 2 | Neural Spline Flows, coupling and autoregressive 3 | 4 | Paper reference: Durkan et al https://arxiv.org/abs/1906.04032 5 | Code reference: slightly modified https://github.com/tonyduan/normalizing-flows/blob/master/nf/flows.py 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.init as init 12 | import torch.nn.functional as F 13 | 14 | from nflib.nets import MLP 15 | 16 | DEFAULT_MIN_BIN_WIDTH = 1e-3 17 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 18 | DEFAULT_MIN_DERIVATIVE = 1e-3 19 | 20 | def searchsorted(bin_locations, inputs, eps=1e-6): 21 | bin_locations[..., -1] += eps 22 | return torch.sum( 23 | inputs[..., None] >= bin_locations, 24 | dim=-1 25 | ) - 1 26 | 27 | def unconstrained_RQS(inputs, unnormalized_widths, unnormalized_heights, 28 | unnormalized_derivatives, inverse=False, 29 | tail_bound=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, 30 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 31 | min_derivative=DEFAULT_MIN_DERIVATIVE): 32 | inside_intvl_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 33 | outside_interval_mask = ~inside_intvl_mask 34 | 35 | outputs = torch.zeros_like(inputs) 36 | logabsdet = torch.zeros_like(inputs) 37 | 38 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 39 | constant = np.log(np.exp(1 - min_derivative) - 1) 40 | unnormalized_derivatives[..., 0] = constant 41 | unnormalized_derivatives[..., -1] = constant 42 | 43 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 44 | logabsdet[outside_interval_mask] = 0 45 | 46 | outputs[inside_intvl_mask], logabsdet[inside_intvl_mask] = RQS( 47 | inputs=inputs[inside_intvl_mask], 48 | unnormalized_widths=unnormalized_widths[inside_intvl_mask, :], 49 | unnormalized_heights=unnormalized_heights[inside_intvl_mask, :], 50 | unnormalized_derivatives=unnormalized_derivatives[inside_intvl_mask, :], 51 | inverse=inverse, 52 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 53 | min_bin_width=min_bin_width, 54 | min_bin_height=min_bin_height, 55 | min_derivative=min_derivative 56 | ) 57 | return outputs, logabsdet 58 | 59 | def RQS(inputs, unnormalized_widths, unnormalized_heights, 60 | unnormalized_derivatives, inverse=False, left=0., right=1., 61 | bottom=0., top=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, 62 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 63 | min_derivative=DEFAULT_MIN_DERIVATIVE): 64 | if torch.min(inputs) < left or torch.max(inputs) > right: 65 | raise ValueError("Input outside domain") 66 | 67 | num_bins = unnormalized_widths.shape[-1] 68 | 69 | if min_bin_width * num_bins > 1.0: 70 | raise ValueError('Minimal bin width too large for the number of bins') 71 | if min_bin_height * num_bins > 1.0: 72 | raise ValueError('Minimal bin height too large for the number of 
bins') 73 | 74 | widths = F.softmax(unnormalized_widths, dim=-1) 75 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 76 | cumwidths = torch.cumsum(widths, dim=-1) 77 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 78 | cumwidths = (right - left) * cumwidths + left 79 | cumwidths[..., 0] = left 80 | cumwidths[..., -1] = right 81 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 82 | 83 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 84 | 85 | heights = F.softmax(unnormalized_heights, dim=-1) 86 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 87 | cumheights = torch.cumsum(heights, dim=-1) 88 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 89 | cumheights = (top - bottom) * cumheights + bottom 90 | cumheights[..., 0] = bottom 91 | cumheights[..., -1] = top 92 | heights = cumheights[..., 1:] - cumheights[..., :-1] 93 | 94 | if inverse: 95 | bin_idx = searchsorted(cumheights, inputs)[..., None] 96 | else: 97 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 98 | 99 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 100 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 101 | 102 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 103 | delta = heights / widths 104 | input_delta = delta.gather(-1, bin_idx)[..., 0] 105 | 106 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 107 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx) 108 | input_derivatives_plus_one = input_derivatives_plus_one[..., 0] 109 | 110 | input_heights = heights.gather(-1, bin_idx)[..., 0] 111 | 112 | if inverse: 113 | a = (((inputs - input_cumheights) * (input_derivatives \ 114 | + input_derivatives_plus_one - 2 * input_delta) \ 115 | + input_heights * (input_delta - input_derivatives))) 116 | b = (input_heights * input_derivatives - (inputs - input_cumheights) \ 117 | * (input_derivatives + input_derivatives_plus_one \ 118 | - 2 * input_delta)) 119 | c = - input_delta * (inputs - input_cumheights) 120 | 121 | discriminant = b.pow(2) - 4 * a * c 122 | assert (discriminant >= 0).all() 123 | 124 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 125 | outputs = root * input_bin_widths + input_cumwidths 126 | 127 | theta_one_minus_theta = root * (1 - root) 128 | denominator = input_delta \ 129 | + ((input_derivatives + input_derivatives_plus_one \ 130 | - 2 * input_delta) * theta_one_minus_theta) 131 | derivative_numerator = input_delta.pow(2) \ 132 | * (input_derivatives_plus_one * root.pow(2) \ 133 | + 2 * input_delta * theta_one_minus_theta \ 134 | + input_derivatives * (1 - root).pow(2)) 135 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 136 | return outputs, -logabsdet 137 | else: 138 | theta = (inputs - input_cumwidths) / input_bin_widths 139 | theta_one_minus_theta = theta * (1 - theta) 140 | 141 | numerator = input_heights * (input_delta * theta.pow(2) \ 142 | + input_derivatives * theta_one_minus_theta) 143 | denominator = input_delta + ((input_derivatives \ 144 | + input_derivatives_plus_one - 2 * input_delta) \ 145 | * theta_one_minus_theta) 146 | outputs = input_cumheights + numerator / denominator 147 | 148 | derivative_numerator = input_delta.pow(2) \ 149 | * (input_derivatives_plus_one * theta.pow(2) \ 150 | + 2 * input_delta * theta_one_minus_theta \ 151 | + input_derivatives * (1 - theta).pow(2)) 152 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 153 | return outputs, logabsdet 154 | 155 
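# Note: NSF_AR below is the autoregressive variant (one conditioner per dimension), while
# NSF_CL further down is the coupling-layer variant (one conditioner per half). Hedged sketch
# of the parameterization both expect (illustration only): each transformed dimension needs
# 3*K - 1 numbers -- K unnormalized bin widths, K unnormalized bin heights and K - 1
# unnormalized knot derivatives -- consumed roughly as
#
#   out = conditioner(x_cond)              # shape (..., 3*K - 1)
#   W, H, D = torch.split(out, K, dim=-1)  # widths, heights, and the K - 1 derivatives
#   W, H = 2 * B * torch.softmax(W, -1), 2 * B * torch.softmax(H, -1)
#   D = F.softplus(D)
#   y, logabsdet = unconstrained_RQS(x, W, H, D, tail_bound=B)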
| class NSF_AR(nn.Module):
156 |     """ Neural spline flow, autoregressive layer, [Durkan et al. 2019] """
157 | 
158 |     def __init__(self, dim, K=5, B=3, hidden_dim=8, base_network=MLP):
159 |         super().__init__()
160 |         self.dim = dim
161 |         self.K = K
162 |         self.B = B
163 |         self.layers = nn.ModuleList()
164 |         self.init_param = nn.Parameter(torch.Tensor(3 * K - 1))
165 |         for i in range(1, dim):
166 |             self.layers += [base_network(i, 3 * K - 1, hidden_dim)]
167 |         self.reset_parameters()
168 | 
169 |     def reset_parameters(self):
170 |         init.uniform_(self.init_param, - 1 / 2, 1 / 2)
171 | 
172 |     def forward(self, x):
173 |         z = torch.zeros_like(x)
174 |         log_det = torch.zeros(z.shape[0])
175 |         for i in range(self.dim):
176 |             if i == 0:
177 |                 init_param = self.init_param.expand(x.shape[0], 3 * self.K - 1)
178 |                 W, H, D = torch.split(init_param, self.K, dim = 1)
179 |             else:
180 |                 out = self.layers[i - 1](x[:, :i])
181 |                 W, H, D = torch.split(out, self.K, dim = 1)
182 |             W, H = torch.softmax(W, dim = 1), torch.softmax(H, dim = 1)
183 |             W, H = 2 * self.B * W, 2 * self.B * H
184 |             D = F.softplus(D)
185 |             z[:, i], ld = unconstrained_RQS(x[:, i], W, H, D, inverse=False, tail_bound=self.B)
186 |             log_det += ld
187 |         return z, log_det
188 | 
189 |     def backward(self, z):
190 |         x = torch.zeros_like(z)
191 |         log_det = torch.zeros(x.shape[0])
192 |         for i in range(self.dim):
193 |             if i == 0:
194 |                 init_param = self.init_param.expand(x.shape[0], 3 * self.K - 1)
195 |                 W, H, D = torch.split(init_param, self.K, dim = 1)
196 |             else:
197 |                 out = self.layers[i - 1](x[:, :i])
198 |                 W, H, D = torch.split(out, self.K, dim = 1)
199 |             W, H = torch.softmax(W, dim = 1), torch.softmax(H, dim = 1)
200 |             W, H = 2 * self.B * W, 2 * self.B * H
201 |             D = F.softplus(D)
202 |             x[:, i], ld = unconstrained_RQS(z[:, i], W, H, D, inverse = True, tail_bound = self.B)
203 |             log_det += ld
204 |         return x, log_det
205 | 
206 | 
207 | class NSF_CL(nn.Module):
208 |     """ Neural spline flow, coupling layer, [Durkan et al. 
2019] """ 209 | 210 | def __init__(self, dim, K=5, B=3, hidden_dim=8, base_network=MLP): 211 | super().__init__() 212 | self.dim = dim 213 | self.K = K 214 | self.B = B 215 | self.f1 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim) 216 | self.f2 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim) 217 | 218 | def forward(self, x): 219 | log_det = torch.zeros(x.shape[0]) 220 | lower, upper = x[:, :self.dim // 2], x[:, self.dim // 2:] 221 | out = self.f1(lower).reshape(-1, self.dim // 2, 3 * self.K - 1) 222 | W, H, D = torch.split(out, self.K, dim = 2) 223 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 224 | W, H = 2 * self.B * W, 2 * self.B * H 225 | D = F.softplus(D) 226 | upper, ld = unconstrained_RQS(upper, W, H, D, inverse=False, tail_bound=self.B) 227 | log_det += torch.sum(ld, dim = 1) 228 | out = self.f2(upper).reshape(-1, self.dim // 2, 3 * self.K - 1) 229 | W, H, D = torch.split(out, self.K, dim = 2) 230 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 231 | W, H = 2 * self.B * W, 2 * self.B * H 232 | D = F.softplus(D) 233 | lower, ld = unconstrained_RQS(lower, W, H, D, inverse=False, tail_bound=self.B) 234 | log_det += torch.sum(ld, dim = 1) 235 | return torch.cat([lower, upper], dim = 1), log_det 236 | 237 | def backward(self, z): 238 | log_det = torch.zeros(z.shape[0]) 239 | lower, upper = z[:, :self.dim // 2], z[:, self.dim // 2:] 240 | out = self.f2(upper).reshape(-1, self.dim // 2, 3 * self.K - 1) 241 | W, H, D = torch.split(out, self.K, dim = 2) 242 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 243 | W, H = 2 * self.B * W, 2 * self.B * H 244 | D = F.softplus(D) 245 | lower, ld = unconstrained_RQS(lower, W, H, D, inverse=True, tail_bound=self.B) 246 | log_det += torch.sum(ld, dim = 1) 247 | out = self.f1(lower).reshape(-1, self.dim // 2, 3 * self.K - 1) 248 | W, H, D = torch.split(out, self.K, dim = 2) 249 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 250 | W, H = 2 * self.B * W, 2 * self.B * H 251 | D = F.softplus(D) 252 | upper, ld = unconstrained_RQS(upper, W, H, D, inverse = True, tail_bound = self.B) 253 | log_det += torch.sum(ld, dim = 1) 254 | return torch.cat([lower, upper], dim = 1), log_det --------------------------------------------------------------------------------