├── .gitignore ├── LICENSE ├── README.md ├── examples ├── __init__.py ├── ex_1d.png ├── ex_1d.py ├── ex_2d.png └── ex_2d.py └── src ├── flows.py ├── models.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tony Duan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Normalizing Flow Models 2 | 3 | Last update: December 2022. 4 | 5 | --- 6 | 7 | Lightweight normalizing flows for generative modeling in PyTorch. 8 | 9 | #### Setup 10 | 11 | ```math 12 | \begin{align*} 13 | \mathbf{x} & = f_\theta^{-1}(\mathbf{z}) & \mathbf{z} & = f_\theta(\mathbf{x}), 14 | \end{align*} 15 | ``` 16 | 17 | where $f_\theta:\mathbb{R}^d \mapsto \mathbb{R}^d$ is an invertible function. The Change of Variables formula tells us that 18 | ```math 19 | \begin{align*} 20 | \underbrace{p(\mathbf{x})}_{\text{over }\mathbf{x}} &= \underbrace{p\left(f_\theta(\mathbf{x})\right)}_{\text{over }\mathbf{z}} \left|\mathrm{det}\left(\frac{\partial f_\theta(\mathbf{x})}{\partial \mathbf{x}}\right)\right| 21 | \end{align*} 22 | ``` 23 | 24 | Here $\frac{\partial f_\theta(\mathbf{x})}{\partial \mathbf{x}}$ denotes the $d \times d$ Jacobian (this needs to be easy to compute). 25 | 26 | We typically choose a simple distribution over the latent space, e.g. $\mathbf{z}\sim N(\mathbf{0},\mathbf{I})$. 27 | 28 | Suppose we compose functions $f_\theta(\mathbf{x}) = f_1\circ f_2 \circ \dots \circ f_k(\mathbf{x};\theta)$. The log-likelihood then decomposes nicely. 29 | ```math 30 | \begin{align*} 31 | \log p(\mathbf{x}) & = \log p\left(f_\theta(\mathbf{x})\right) + \sum_{i=1}^k\log\left|\mathrm{det}\frac{\partial f_i(\mathbf{x};\theta)}{\partial \mathbf{x}}\right| 32 | \end{align*} 33 | ``` 34 | Sampling can be done easily, as long as $f_\theta^{-1}$ is tractable. 35 | 36 | #### Implemented Flows 37 | 38 | **Planar and radial flows** [1]. Note these have no algebraic inverse $f^{-1}(\mathbf{x})$.
39 | ```math 40 | \begin{align*} 41 | f(\mathbf{x}) & = \mathbf{x} + \mathbf{u}h(\mathbf{w}^\top \mathbf{x} + b)\\ 42 | f(\mathbf{x}) & = \mathbf{x} + \frac{\beta(\mathbf{x}-\mathbf{x}_0)}{\alpha + \|\mathbf{x}-\mathbf{x}_0\|} 43 | \end{align*} 44 | ``` 45 | **Real NVP** [2]. Partition the vector $\mathbf{x}$ into components $\mathbf{x}^{(1)},\mathbf{x}^{(2)}$. Let $s,t$ be arbitrary neural networks mapping $\mathbb{R}^{d/2} \mapsto \mathbb{R}^{d/2}$ (an even split, as in this implementation). 46 | ```math 47 | \begin{align*} 48 | f(\mathbf{x}^{(1)}) &= t(\mathbf{x}^{(2)}) + \mathbf{x}^{(1)}\odot \exp s(\mathbf{x}^{(2)})\\ 49 | f(\mathbf{x}^{(2)}) &= t(\mathbf{x}^{(1)}) + \mathbf{x}^{(2)}\odot \exp s(\mathbf{x}^{(1)}) 50 | \end{align*} 51 | ``` 52 | Here the diagonal of the Jacobian is simply $[\exp s(\mathbf{x}^{(2)}),\ \exp s(\mathbf{x}^{(1)})]$. 53 | 54 | **Invertible 1x1 Convolution** [3]. Use an LU decomposition for computational efficiency. 55 | ```math 56 | f(\mathbf{x})= W\mathbf{x}, \text{ where }W \text{ is square} 57 | ``` 58 | **ActNorm** [3]. Even more straightforward. 59 | ```math 60 | f(\mathbf{x}) = W\mathbf{x} + b, \text{ where }W \text{ is diagonal} 61 | ``` 62 | **Masked Autoregressive Flow** [4]. For each dimension of $\mathbf{x}$, use a neural network to predict scalars $\mu,\alpha$. 63 | ```math 64 | f(x_i) = (x_i - \mu(x_{< i})) / \exp(\alpha(x_{< i})) 65 | ``` 66 | Here the diagonal of the Jacobian is $\exp(-\alpha(x_{< i}))$. 67 | 68 | **Neural Spline Flow** [5]. Two versions: auto-regressive and coupling. 69 | ```math 70 | \begin{align*} 71 | f(x_i) & = \mathrm{RQS}_{g(x_{< i})}(x_i), \text{ (autoregressive) }\\ 72 | f(\mathbf{x}^{(1)}) & = \mathrm{RQS}_{g(\mathbf{x}^{(2)})}(\mathbf{x}^{(1)}) \text{ (coupling)}\\ 73 | f(\mathbf{x}^{(2)}) & = \mathrm{RQS}_{g(\mathbf{x}^{(1)})}(\mathbf{x}^{(2)}) 74 | \end{align*} 75 | ``` 76 | #### Examples 77 | 78 | Below we show examples (in 1D and 2D) transforming a mixture of Gaussians into a unit Gaussian. 79 | 80 | ![](examples/ex_1d.png) 81 | 82 | ![](examples/ex_2d.png) 83 | 84 | #### References 85 | 86 | [1] Rezende, D.J., and Mohamed, S. (2015). Variational Inference with Normalizing Flows. In Proceedings of the 32nd International Conference on Machine Learning, Volume 37, pp. 1530–1538 (JMLR.org). 87 | 88 | [2] Dinh, L., Sohl-Dickstein, J., and Bengio, S. (2017). Density Estimation using Real NVP. In International Conference on Learning Representations. 89 | 90 | [3] Kingma, D.P., and Dhariwal, P. (2018). Glow: Generative Flow with Invertible 1x1 Convolutions. In Advances in Neural Information Processing Systems 31, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, eds. (Curran Associates, Inc.), pp. 10215–10224. 91 | 92 | [4] Papamakarios, G., Pavlakou, T., and Murray, I. (2017). Masked Autoregressive Flow for Density Estimation. In Advances in Neural Information Processing Systems 30, I. Guyon, U.V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, eds. (Curran Associates, Inc.), pp. 2338–2347. 93 | 94 | [5] Durkan, C., Bekasov, A., Murray, I., and Papamakarios, G. (2019). Neural Spline Flows. In Advances in Neural Information Processing Systems 32. 95 | 96 | #### License 97 | 98 | This code is available under the MIT License.
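#### Quick usage sketch

For concreteness, the training objective above boils down to a few lines of code. The following is a minimal sketch in the spirit of examples/ex_2d.py (which defaults to NSF_CL; RealNVP is used here for brevity), using only the `RealNVP` and `NormalizingFlowModel` classes from `src/`. It assumes the repository root is on the Python path and uses synthetic `torch.randn` data as a stand-in for a real dataset.

```python
import torch
import torch.optim as optim

from src.flows import RealNVP
from src.models import NormalizingFlowModel

# Stack two RealNVP coupling layers on 2-D data.
flows = [RealNVP(dim=2, hidden_dim=8) for _ in range(2)]
model = NormalizingFlowModel(dim=2, flows=flows)
optimizer = optim.Adam(model.parameters(), lr=0.005)

x = torch.randn(512, 2)  # stand-in for standardized training data

for _ in range(500):
    optimizer.zero_grad()
    z, prior_logprob, log_det = model(x)       # z = f(x), log p(z), sum_i log|det J_i|
    loss = -(prior_logprob + log_det).mean()   # negative log-likelihood from the formula above
    loss.backward()
    optimizer.step()

samples = model.sample(256)                    # draw z ~ N(0, I), push through f^{-1}
```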
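To tie the change-of-variables formula back to the code, here is a small sanity check (again a sketch, assuming the repository root is importable): it brute-forces the Jacobian of the library's `ActNorm` layer with `torch.autograd.functional.jacobian` and compares its log-determinant to the `log_det` the layer reports.

```python
import torch
from src.flows import ActNorm  # f(x) = x * exp(log_sigma) + mu, so the Jacobian is diagonal

torch.manual_seed(0)
flow = ActNorm(dim=2)
flow.log_sigma.data = torch.randn(2) * 0.1     # perturb the (initially zero) log-scales

x0 = torch.randn(2)
z, log_det = flow(x0)

# Brute-force the 2x2 Jacobian df/dx and take log|det|; it should match log_det.
J = torch.autograd.functional.jacobian(lambda v: flow(v)[0], x0)
print(torch.log(torch.abs(torch.det(J))).item(), log_det.item())
```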
99 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/normalizing-flows/3ee9127e3753be3878a0a4fd3ea435ff2d31d143/examples/__init__.py -------------------------------------------------------------------------------- /examples/ex_1d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/normalizing-flows/3ee9127e3753be3878a0a4fd3ea435ff2d31d143/examples/ex_1d.png -------------------------------------------------------------------------------- /examples/ex_1d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy as sp 3 | import scipy.stats 4 | import itertools 5 | import logging 6 | 7 | import matplotlib.pyplot as plt 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from argparse import ArgumentParser 12 | 13 | from src.flows import * 14 | from src.models import NormalizingFlowModel 15 | 16 | 17 | def gen_data(n=512): 18 | return np.r_[np.random.randn(n // 2, 1) + np.array([2]), 19 | np.random.randn(n // 2, 1) + np.array([-2])] 20 | 21 | def plot_data(x, bandwidth = 0.2, **kwargs): 22 | kde = sp.stats.gaussian_kde(x[:,0]) 23 | x_axis = np.linspace(-5, 5, 200) 24 | plt.plot(x_axis, kde(x_axis), **kwargs) 25 | plt.axis("off") 26 | 27 | 28 | if __name__ == "__main__": 29 | 30 | argparser = ArgumentParser() 31 | argparser.add_argument("--n", default=512, type=int) 32 | argparser.add_argument("--flows", default=2, type=int) 33 | argparser.add_argument("--flow", default="NSF_AR", type=str) 34 | argparser.add_argument("--iterations", default=500, type=int) 35 | args = argparser.parse_args() 36 | 37 | logging.basicConfig(level=logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | flow = eval(args.flow) 41 | flows = [flow(dim=1) for _ in range(args.flows)] 42 | model = NormalizingFlowModel(dim=1, flows=flows) 43 | 44 | optimizer = optim.Adam(model.parameters(), lr=0.005) 45 | x = torch.Tensor(gen_data(args.n)) 46 | 47 | for i in range(x.shape[1]): 48 | x[:,i] = (x[:,i] - torch.mean(x[:,i])) / torch.std(x[:,i]) 49 | 50 | for i in range(args.iterations): 51 | optimizer.zero_grad() 52 | z, prior_logprob, log_det = model(x) 53 | logprob = prior_logprob + log_det 54 | loss = -torch.mean(prior_logprob + log_det) 55 | loss.backward() 56 | optimizer.step() 57 | if i % 100 == 0: 58 | logger.info(f"Iter: {i}\t" + 59 | f"Logprob: {logprob.mean().data:.2f}\t" + 60 | f"Prior: {prior_logprob.mean().data:.2f}\t" + 61 | f"LogDet: {log_det.mean().data:.2f}") 62 | 63 | plt.figure(figsize=(8, 3)) 64 | plt.subplot(1, 3, 1) 65 | plot_data(x, color="black", alpha=0.5) 66 | plt.title("Training data") 67 | plt.subplot(1, 3, 2) 68 | plot_data(z.data, color="darkblue", alpha=0.5) 69 | plt.title("Latent space") 70 | plt.subplot(1, 3, 3) 71 | samples = model.sample(500).data 72 | plot_data(samples, color="black", alpha=0.5) 73 | plt.title("Generated samples") 74 | plt.savefig("./examples/ex_1d.png") 75 | plt.show() 76 | -------------------------------------------------------------------------------- /examples/ex_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/normalizing-flows/3ee9127e3753be3878a0a4fd3ea435ff2d31d143/examples/ex_2d.png 
-------------------------------------------------------------------------------- /examples/ex_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import logging 4 | 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | from argparse import ArgumentParser 10 | 11 | from src.flows import * 12 | from src.models import NormalizingFlowModel 13 | 14 | 15 | def gen_data(n=512): 16 | return np.r_[np.random.randn(n // 3, 2) + np.array([0, 6]), 17 | np.random.randn(n // 3, 2) + np.array([2.5, 3]), 18 | np.random.randn(n // 3, 2) + np.array([-2.5, 3])] 19 | 20 | 21 | def gen_mixture_data(n=512): 22 | return np.r_[np.random.randn(n // 2, 2) + np.array([5, 3]), 23 | np.random.randn(n // 2, 2) + np.array([-5, 3])] 24 | 25 | 26 | 27 | def plot_data(x, **kwargs): 28 | plt.scatter(x[:,0], x[:,1], marker="x", **kwargs) 29 | plt.xlim((-3, 3)) 30 | plt.ylim((-3, 3)) 31 | plt.axis("off") 32 | 33 | 34 | if __name__ == "__main__": 35 | 36 | argparser = ArgumentParser() 37 | argparser.add_argument("--n", default=512, type=int) 38 | argparser.add_argument("--flows", default=2, type=int) 39 | argparser.add_argument("--flow", default="NSF_CL", type=str) 40 | argparser.add_argument("--iterations", default=500, type=int) 41 | argparser.add_argument("--use-mixture", action="store_true") 42 | argparser.add_argument("--convolve", action="store_true") 43 | argparser.add_argument("--actnorm", action="store_true") 44 | args = argparser.parse_args() 45 | 46 | logging.basicConfig(level=logging.INFO) 47 | logger = logging.getLogger(__name__) 48 | 49 | flow = eval(args.flow) 50 | flows = [flow(dim=2) for _ in range(args.flows)] 51 | if args.convolve: 52 | convs = [OneByOneConv(dim=2) for _ in range(args.flows)] 53 | flows = list(itertools.chain(*zip(convs, flows))) 54 | if args.actnorm: 55 | actnorms = [ActNorm(dim=2) for _ in range(args.flows)] 56 | flows = list(itertools.chain(*zip(actnorms, flows))) 57 | 58 | model = NormalizingFlowModel(dim=2, flows=flows) 59 | 60 | optimizer = optim.Adam(model.parameters(), lr=0.005) 61 | if args.use_mixture: 62 | x = torch.Tensor(gen_mixture_data(args.n)) 63 | else: 64 | x = torch.Tensor(gen_data(args.n)) 65 | 66 | for i in range(x.shape[1]): 67 | x[:,i] = (x[:,i] - torch.mean(x[:,i])) / torch.std(x[:,i]) 68 | 69 | for i in range(args.iterations): 70 | optimizer.zero_grad() 71 | z, prior_logprob, log_det = model(x) 72 | logprob = prior_logprob + log_det 73 | loss = -torch.mean(prior_logprob + log_det) 74 | loss.backward() 75 | optimizer.step() 76 | if i % 100 == 0: 77 | logger.info(f"Iter: {i}\t" + 78 | f"Logprob: {logprob.mean().data:.2f}\t" + 79 | f"Prior: {prior_logprob.mean().data:.2f}\t" + 80 | f"LogDet: {log_det.mean().data:.2f}") 81 | 82 | plt.figure(figsize=(8, 3)) 83 | plt.subplot(1, 3, 1) 84 | plot_data(x, color="black", alpha=0.5) 85 | plt.title("Training data") 86 | plt.subplot(1, 3, 2) 87 | plot_data(z.data, color="darkblue", alpha=0.5) 88 | plt.title("Latent space") 89 | plt.subplot(1, 3, 3) 90 | samples = model.sample(500).data 91 | plot_data(samples, color="black", alpha=0.5) 92 | plt.title("Generated samples") 93 | plt.savefig("./examples/ex_2d.png") 94 | plt.show() 95 | 96 | for f in flows: 97 | x = f(x)[0].data 98 | plot_data(x, color="black", alpha=0.5) 99 | plt.show() 100 | 101 | -------------------------------------------------------------------------------- /src/flows.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import scipy as sp 5 | import scipy.linalg 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | from src.utils import unconstrained_RQS 12 | 13 | 14 | # supported non-linearities: note that the function must be invertible 15 | functional_derivatives = { 16 | torch.tanh: lambda x: 1 - torch.pow(torch.tanh(x), 2), 17 | F.leaky_relu: lambda x: (x > 0).type(torch.FloatTensor) + \ 18 | (x < 0).type(torch.FloatTensor) * 0.01, 19 | F.elu: lambda x: (x > 0).type(torch.FloatTensor) + \ 20 | (x < 0).type(torch.FloatTensor) * torch.exp(x) 21 | } 22 | 23 | 24 | class Planar(nn.Module): 25 | """ 26 | Planar flow. 27 | 28 | z = f(x) = x + u h(wᵀx + b) 29 | 30 | [Rezende and Mohamed, 2015] 31 | """ 32 | def __init__(self, dim, nonlinearity=torch.tanh): 33 | super().__init__() 34 | self.h = nonlinearity 35 | self.w = nn.Parameter(torch.Tensor(dim)) 36 | self.u = nn.Parameter(torch.Tensor(dim)) 37 | self.b = nn.Parameter(torch.Tensor(1)) 38 | self.reset_parameters(dim) 39 | 40 | def reset_parameters(self, dim): 41 | init.uniform_(self.w, -math.sqrt(1/dim), math.sqrt(1/dim)) 42 | init.uniform_(self.u, -math.sqrt(1/dim), math.sqrt(1/dim)) 43 | init.uniform_(self.b, -math.sqrt(1/dim), math.sqrt(1/dim)) 44 | 45 | def forward(self, x): 46 | """ 47 | Given x, returns z = f(x) and the log-determinant log|det df/dx|. 48 | 49 | Returns z of shape (batch_size, dim) and log_det of shape (batch_size,), 50 | one log-determinant per sample. 51 | """ 52 | if self.h in (F.elu, F.leaky_relu): 53 | u = self.u 54 | elif self.h == torch.tanh: 55 | scal = torch.log(1+torch.exp(self.w @ self.u)) - self.w @ self.u - 1 56 | u = self.u + scal * self.w / torch.norm(self.w) ** 2 57 | else: 58 | raise NotImplementedError("Non-linearity is not supported.") 59 | lin = torch.unsqueeze(x @ self.w, 1) + self.b 60 | z = x + u * self.h(lin) 61 | phi = functional_derivatives[self.h](lin) * self.w 62 | log_det = torch.log(torch.abs(1 + phi @ u) + 1e-4) 63 | return z, log_det 64 | 65 | def inverse(self, z): 66 | raise NotImplementedError("Planar flow has no algebraic inverse.") 67 | 68 | 69 | class Radial(nn.Module): 70 | """ 71 | Radial flow. 72 | 73 | z = f(x) = x + β h(α, r)(x − x0), with r = ‖x − x0‖ and h(α, r) = 1/(α + r) 74 | 75 | [Rezende and Mohamed 2015] 76 | """ 77 | def __init__(self, dim): 78 | super().__init__() 79 | self.x0 = nn.Parameter(torch.Tensor(dim)) 80 | self.log_alpha = nn.Parameter(torch.Tensor(1)) 81 | self.beta = nn.Parameter(torch.Tensor(1)) 82 | self.reset_parameters(dim) 83 | def reset_parameters(self, dim): 84 | init.uniform_(self.x0, -math.sqrt(1/dim), math.sqrt(1/dim)) 85 | init.uniform_(self.log_alpha, -math.sqrt(1/dim), math.sqrt(1/dim)) 86 | init.uniform_(self.beta, -math.sqrt(1/dim), math.sqrt(1/dim)) 87 | 88 | def forward(self, x): 89 | """ 90 | Given x, returns z and the log-determinant log|df/dx|. 91 | """ 92 | m, n = x.shape 93 | r = torch.norm(x - self.x0, dim=1, keepdim=True) 94 | h = 1 / (torch.exp(self.log_alpha) + r) 95 | beta = -torch.exp(self.log_alpha) + torch.log(1 + torch.exp(self.beta)) 96 | z = x + beta * h * (x - self.x0) 97 | log_det = (n - 1) * torch.log(1 + beta * h) + \ 98 | torch.log(1 + beta * h - \ 99 | beta * r / (torch.exp(self.log_alpha) + r) ** 2) 100 | return z, log_det.squeeze(-1) 101 | 102 | 103 | class FCNN(nn.Module): 104 | """ 105 | Simple fully connected neural network.
106 | """ 107 | def __init__(self, in_dim, out_dim, hidden_dim): 108 | super().__init__() 109 | self.network = nn.Sequential( 110 | nn.Linear(in_dim, hidden_dim), 111 | nn.Tanh(), 112 | nn.Linear(hidden_dim, hidden_dim), 113 | nn.Tanh(), 114 | nn.Linear(hidden_dim, out_dim), 115 | ) 116 | 117 | def forward(self, x): 118 | return self.network(x) 119 | 120 | 121 | class RealNVP(nn.Module): 122 | """ 123 | Non-volume preserving flow. 124 | 125 | [Dinh et. al. 2017] 126 | """ 127 | def __init__(self, dim, hidden_dim = 8, base_network=FCNN): 128 | super().__init__() 129 | self.dim = dim 130 | self.t1 = base_network(dim // 2, dim // 2, hidden_dim) 131 | self.s1 = base_network(dim // 2, dim // 2, hidden_dim) 132 | self.t2 = base_network(dim // 2, dim // 2, hidden_dim) 133 | self.s2 = base_network(dim // 2, dim // 2, hidden_dim) 134 | 135 | def forward(self, x): 136 | lower, upper = x[:,:self.dim // 2], x[:,self.dim // 2:] 137 | t1_transformed = self.t1(lower) 138 | s1_transformed = self.s1(lower) 139 | upper = t1_transformed + upper * torch.exp(s1_transformed) 140 | t2_transformed = self.t2(upper) 141 | s2_transformed = self.s2(upper) 142 | lower = t2_transformed + lower * torch.exp(s2_transformed) 143 | z = torch.cat([lower, upper], dim=1) 144 | log_det = torch.sum(s1_transformed, dim=1) + \ 145 | torch.sum(s2_transformed, dim=1) 146 | return z, log_det 147 | 148 | def inverse(self, z): 149 | lower, upper = z[:,:self.dim // 2], z[:,self.dim // 2:] 150 | t2_transformed = self.t2(upper) 151 | s2_transformed = self.s2(upper) 152 | lower = (lower - t2_transformed) * torch.exp(-s2_transformed) 153 | t1_transformed = self.t1(lower) 154 | s1_transformed = self.s1(lower) 155 | upper = (upper - t1_transformed) * torch.exp(-s1_transformed) 156 | x = torch.cat([lower, upper], dim=1) 157 | log_det = torch.sum(-s1_transformed, dim=1) + \ 158 | torch.sum(-s2_transformed, dim=1) 159 | return x, log_det 160 | 161 | 162 | class MAF(nn.Module): 163 | """ 164 | Masked auto-regressive flow. 165 | 166 | [Papamakarios et al. 2018] 167 | """ 168 | def __init__(self, dim, hidden_dim = 8, base_network=FCNN): 169 | super().__init__() 170 | self.dim = dim 171 | self.layers = nn.ModuleList() 172 | self.initial_param = nn.Parameter(torch.Tensor(2)) 173 | for i in range(1, dim): 174 | self.layers += [base_network(i, 2, hidden_dim)] 175 | self.reset_parameters() 176 | 177 | def reset_parameters(self): 178 | init.uniform_(self.initial_param, -math.sqrt(0.5), math.sqrt(0.5)) 179 | 180 | def forward(self, x): 181 | z = torch.zeros_like(x) 182 | log_det = torch.zeros(z.shape[0]) 183 | for i in range(self.dim): 184 | if i == 0: 185 | mu, alpha = self.initial_param[0], self.initial_param[1] 186 | else: 187 | out = self.layers[i - 1](x[:, :i]) 188 | mu, alpha = out[:, 0], out[:, 1] 189 | z[:, i] = (x[:, i] - mu) / torch.exp(alpha) 190 | log_det -= alpha 191 | return z.flip(dims=(1,)), log_det 192 | 193 | def inverse(self, z): 194 | x = torch.zeros_like(z) 195 | log_det = torch.zeros(z.shape[0]) 196 | z = z.flip(dims=(1,)) 197 | for i in range(self.dim): 198 | if i == 0: 199 | mu, alpha = self.initial_param[0], self.initial_param[1] 200 | else: 201 | out = self.layers[i - 1](x[:, :i]) 202 | mu, alpha = out[:, 0], out[:, 1] 203 | x[:, i] = mu + torch.exp(alpha) * z[:, i] 204 | log_det += alpha 205 | return x, log_det 206 | 207 | 208 | class ActNorm(nn.Module): 209 | """ 210 | ActNorm layer. 211 | 212 | [Kingma and Dhariwal, 2018.] 
213 | """ 214 | def __init__(self, dim): 215 | super().__init__() 216 | self.dim = dim 217 | self.mu = nn.Parameter(torch.zeros(dim, dtype = torch.float)) 218 | self.log_sigma = nn.Parameter(torch.zeros(dim, dtype = torch.float)) 219 | 220 | def forward(self, x): 221 | z = x * torch.exp(self.log_sigma) + self.mu 222 | log_det = torch.sum(self.log_sigma) 223 | return z, log_det 224 | 225 | def inverse(self, z): 226 | x = (z - self.mu) / torch.exp(self.log_sigma) 227 | log_det = -torch.sum(self.log_sigma) 228 | return x, log_det 229 | 230 | 231 | class OneByOneConv(nn.Module): 232 | """ 233 | Invertible 1x1 convolution. 234 | 235 | [Kingma and Dhariwal, 2018.] 236 | """ 237 | def __init__(self, dim): 238 | super().__init__() 239 | self.dim = dim 240 | W, _ = sp.linalg.qr(np.random.randn(dim, dim)) 241 | P, L, U = sp.linalg.lu(W) 242 | self.P = torch.tensor(P, dtype = torch.float) 243 | self.L = nn.Parameter(torch.tensor(L, dtype = torch.float)) 244 | self.S = nn.Parameter(torch.tensor(np.diag(U), dtype = torch.float)) 245 | self.U = nn.Parameter(torch.triu(torch.tensor(U, dtype = torch.float), 246 | diagonal = 1)) 247 | self.W_inv = None 248 | 249 | def forward(self, x): 250 | L = torch.tril(self.L, diagonal = -1) + torch.diag(torch.ones(self.dim)) 251 | U = torch.triu(self.U, diagonal = 1) 252 | z = x @ self.P @ L @ (U + torch.diag(self.S)) 253 | log_det = torch.sum(torch.log(torch.abs(self.S))) 254 | return z, log_det 255 | 256 | def inverse(self, z): 257 | if not self.W_inv: 258 | L = torch.tril(self.L, diagonal = -1) + \ 259 | torch.diag(torch.ones(self.dim)) 260 | U = torch.triu(self.U, diagonal = 1) 261 | W = self.P @ L @ (U + torch.diag(self.S)) 262 | self.W_inv = torch.inverse(W) 263 | x = z @ self.W_inv 264 | log_det = -torch.sum(torch.log(torch.abs(self.S))) 265 | return x, log_det 266 | 267 | 268 | class NSF_AR(nn.Module): 269 | """ 270 | Neural spline flow, auto-regressive. 271 | 272 | [Durkan et al. 
2019] 273 | """ 274 | def __init__(self, dim, K = 5, B = 3, hidden_dim = 8, base_network = FCNN): 275 | super().__init__() 276 | self.dim = dim 277 | self.K = K 278 | self.B = B 279 | self.layers = nn.ModuleList() 280 | self.init_param = nn.Parameter(torch.Tensor(3 * K - 1)) 281 | for i in range(1, dim): 282 | self.layers += [base_network(i, 3 * K - 1, hidden_dim)] 283 | self.reset_parameters() 284 | 285 | def reset_parameters(self): 286 | init.uniform_(self.init_param, - 1 / 2, 1 / 2) 287 | 288 | def forward(self, x): 289 | z = torch.zeros_like(x) 290 | log_det = torch.zeros(z.shape[0]) 291 | for i in range(self.dim): 292 | if i == 0: 293 | init_param = self.init_param.expand(x.shape[0], 3 * self.K - 1) 294 | W, H, D = torch.split(init_param, self.K, dim = 1) 295 | else: 296 | out = self.layers[i - 1](x[:, :i]) 297 | W, H, D = torch.split(out, self.K, dim = 1) 298 | W, H = torch.softmax(W, dim = 1), torch.softmax(H, dim = 1) 299 | W, H = 2 * self.B * W, 2 * self.B * H 300 | D = F.softplus(D) 301 | z[:, i], ld = unconstrained_RQS( 302 | x[:, i], W, H, D, inverse=False, tail_bound=self.B) 303 | log_det += ld 304 | return z, log_det 305 | 306 | def inverse(self, z): 307 | x = torch.zeros_like(z) 308 | log_det = torch.zeros(x.shape[0]) 309 | for i in range(self.dim): 310 | if i == 0: 311 | init_param = self.init_param.expand(x.shape[0], 3 * self.K - 1) 312 | W, H, D = torch.split(init_param, self.K, dim = 1) 313 | else: 314 | out = self.layers[i - 1](x[:, :i]) 315 | W, H, D = torch.split(out, self.K, dim = 1) 316 | W, H = torch.softmax(W, dim = 1), torch.softmax(H, dim = 1) 317 | W, H = 2 * self.B * W, 2 * self.B * H 318 | D = F.softplus(D) 319 | x[:, i], ld = unconstrained_RQS( 320 | z[:, i], W, H, D, inverse = True, tail_bound = self.B) 321 | log_det += ld 322 | return x, log_det 323 | 324 | 325 | class NSF_CL(nn.Module): 326 | """ 327 | Neural spline flow, coupling layer. 328 | 329 | [Durkan et al. 
2019] 330 | """ 331 | def __init__(self, dim, K = 5, B = 3, hidden_dim = 8, base_network = FCNN): 332 | super().__init__() 333 | self.dim = dim 334 | self.K = K 335 | self.B = B 336 | self.f1 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim) 337 | self.f2 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim) 338 | 339 | def forward(self, x): 340 | log_det = torch.zeros(x.shape[0]) 341 | lower, upper = x[:, :self.dim // 2], x[:, self.dim // 2:] 342 | out = self.f1(lower).reshape(-1, self.dim // 2, 3 * self.K - 1) 343 | W, H, D = torch.split(out, self.K, dim = 2) 344 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 345 | W, H = 2 * self.B * W, 2 * self.B * H 346 | D = F.softplus(D) 347 | upper, ld = unconstrained_RQS( 348 | upper, W, H, D, inverse=False, tail_bound=self.B) 349 | log_det += torch.sum(ld, dim = 1) 350 | out = self.f2(upper).reshape(-1, self.dim // 2, 3 * self.K - 1) 351 | W, H, D = torch.split(out, self.K, dim = 2) 352 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 353 | W, H = 2 * self.B * W, 2 * self.B * H 354 | D = F.softplus(D) 355 | lower, ld = unconstrained_RQS( 356 | lower, W, H, D, inverse=False, tail_bound=self.B) 357 | log_det += torch.sum(ld, dim = 1) 358 | return torch.cat([lower, upper], dim = 1), log_det 359 | 360 | def inverse(self, z): 361 | log_det = torch.zeros(z.shape[0]) 362 | lower, upper = z[:, :self.dim // 2], z[:, self.dim // 2:] 363 | out = self.f2(upper).reshape(-1, self.dim // 2, 3 * self.K - 1) 364 | W, H, D = torch.split(out, self.K, dim = 2) 365 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 366 | W, H = 2 * self.B * W, 2 * self.B * H 367 | D = F.softplus(D) 368 | lower, ld = unconstrained_RQS( 369 | lower, W, H, D, inverse=True, tail_bound=self.B) 370 | log_det += torch.sum(ld, dim = 1) 371 | out = self.f1(lower).reshape(-1, self.dim // 2, 3 * self.K - 1) 372 | W, H, D = torch.split(out, self.K, dim = 2) 373 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 374 | W, H = 2 * self.B * W, 2 * self.B * H 375 | D = F.softplus(D) 376 | upper, ld = unconstrained_RQS( 377 | upper, W, H, D, inverse = True, tail_bound = self.B) 378 | log_det += torch.sum(ld, dim = 1) 379 | return torch.cat([lower, upper], dim = 1), log_det 380 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import MultivariateNormal 4 | 5 | 6 | class NormalizingFlowModel(nn.Module): 7 | 8 | def __init__(self, dim, flows): 9 | super().__init__() 10 | self.prior = MultivariateNormal(torch.zeros(dim), torch.eye(dim)) 11 | self.flows = nn.ModuleList(flows) 12 | 13 | def forward(self, x): 14 | bsz, _ = x.shape 15 | log_det = torch.zeros(bsz) 16 | for flow in self.flows: 17 | x, ld = flow.forward(x) 18 | log_det += ld 19 | z, prior_logprob = x, self.prior.log_prob(x) 20 | return z, prior_logprob, log_det 21 | 22 | def inverse(self, z): 23 | bsz, _ = z.shape 24 | log_det = torch.zeros(bsz) 25 | for flow in self.flows[::-1]: 26 | z, ld = flow.inverse(z) 27 | log_det += ld 28 | x = z 29 | return x, log_det 30 | 31 | def sample(self, n_samples): 32 | z = self.prior.sample((n_samples,)) 33 | x, _ = self.inverse(z) 34 | return x 35 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 
2 | import torch 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Implementation of rational-quadratic splines in this file is taken from 7 | https://github.com/bayesiains/nsf. 8 | 9 | Thank you to the authors for providing well-documented source code! 10 | """ 11 | 12 | DEFAULT_MIN_BIN_WIDTH = 1e-3 13 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 14 | DEFAULT_MIN_DERIVATIVE = 1e-3 15 | 16 | def searchsorted(bin_locations, inputs, eps=1e-6): 17 | bin_locations[..., -1] += eps 18 | return torch.sum( 19 | inputs[..., None] >= bin_locations, 20 | dim=-1 21 | ) - 1 22 | 23 | def unconstrained_RQS(inputs, unnormalized_widths, unnormalized_heights, 24 | unnormalized_derivatives, inverse=False, 25 | tail_bound=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, 26 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 27 | min_derivative=DEFAULT_MIN_DERIVATIVE): 28 | inside_intvl_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 29 | outside_interval_mask = ~inside_intvl_mask 30 | 31 | outputs = torch.zeros_like(inputs) 32 | logabsdet = torch.zeros_like(inputs) 33 | 34 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 35 | constant = np.log(np.exp(1 - min_derivative) - 1) 36 | unnormalized_derivatives[..., 0] = constant 37 | unnormalized_derivatives[..., -1] = constant 38 | 39 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 40 | logabsdet[outside_interval_mask] = 0 41 | 42 | outputs[inside_intvl_mask], logabsdet[inside_intvl_mask] = RQS( 43 | inputs=inputs[inside_intvl_mask], 44 | unnormalized_widths=unnormalized_widths[inside_intvl_mask, :], 45 | unnormalized_heights=unnormalized_heights[inside_intvl_mask, :], 46 | unnormalized_derivatives=unnormalized_derivatives[inside_intvl_mask, :], 47 | inverse=inverse, 48 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 49 | min_bin_width=min_bin_width, 50 | min_bin_height=min_bin_height, 51 | min_derivative=min_derivative 52 | ) 53 | return outputs, logabsdet 54 | 55 | def RQS(inputs, unnormalized_widths, unnormalized_heights, 56 | unnormalized_derivatives, inverse=False, left=0., right=1., 57 | bottom=0., top=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, 58 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 59 | min_derivative=DEFAULT_MIN_DERIVATIVE): 60 | if torch.min(inputs) < left or torch.max(inputs) > right: 61 | raise ValueError("Input outside domain") 62 | 63 | num_bins = unnormalized_widths.shape[-1] 64 | 65 | if min_bin_width * num_bins > 1.0: 66 | raise ValueError('Minimal bin width too large for the number of bins') 67 | if min_bin_height * num_bins > 1.0: 68 | raise ValueError('Minimal bin height too large for the number of bins') 69 | 70 | widths = F.softmax(unnormalized_widths, dim=-1) 71 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 72 | cumwidths = torch.cumsum(widths, dim=-1) 73 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 74 | cumwidths = (right - left) * cumwidths + left 75 | cumwidths[..., 0] = left 76 | cumwidths[..., -1] = right 77 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 78 | 79 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 80 | 81 | heights = F.softmax(unnormalized_heights, dim=-1) 82 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 83 | cumheights = torch.cumsum(heights, dim=-1) 84 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 85 | cumheights = (top - bottom) * cumheights + bottom 86 | cumheights[..., 0] = bottom 87 | cumheights[..., -1] = top 88 | heights = cumheights[..., 1:] - 
cumheights[..., :-1] 89 | 90 | if inverse: 91 | bin_idx = searchsorted(cumheights, inputs)[..., None] 92 | else: 93 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 94 | 95 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 96 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 97 | 98 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 99 | delta = heights / widths 100 | input_delta = delta.gather(-1, bin_idx)[..., 0] 101 | 102 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 103 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx) 104 | input_derivatives_plus_one = input_derivatives_plus_one[..., 0] 105 | 106 | input_heights = heights.gather(-1, bin_idx)[..., 0] 107 | 108 | if inverse: 109 | a = (((inputs - input_cumheights) * (input_derivatives \ 110 | + input_derivatives_plus_one - 2 * input_delta) \ 111 | + input_heights * (input_delta - input_derivatives))) 112 | b = (input_heights * input_derivatives - (inputs - input_cumheights) \ 113 | * (input_derivatives + input_derivatives_plus_one \ 114 | - 2 * input_delta)) 115 | c = - input_delta * (inputs - input_cumheights) 116 | 117 | discriminant = b.pow(2) - 4 * a * c 118 | assert (discriminant >= 0).all() 119 | 120 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 121 | outputs = root * input_bin_widths + input_cumwidths 122 | 123 | theta_one_minus_theta = root * (1 - root) 124 | denominator = input_delta \ 125 | + ((input_derivatives + input_derivatives_plus_one \ 126 | - 2 * input_delta) * theta_one_minus_theta) 127 | derivative_numerator = input_delta.pow(2) \ 128 | * (input_derivatives_plus_one * root.pow(2) \ 129 | + 2 * input_delta * theta_one_minus_theta \ 130 | + input_derivatives * (1 - root).pow(2)) 131 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 132 | return outputs, -logabsdet 133 | else: 134 | theta = (inputs - input_cumwidths) / input_bin_widths 135 | theta_one_minus_theta = theta * (1 - theta) 136 | 137 | numerator = input_heights * (input_delta * theta.pow(2) \ 138 | + input_derivatives * theta_one_minus_theta) 139 | denominator = input_delta + ((input_derivatives \ 140 | + input_derivatives_plus_one - 2 * input_delta) \ 141 | * theta_one_minus_theta) 142 | outputs = input_cumheights + numerator / denominator 143 | 144 | derivative_numerator = input_delta.pow(2) \ 145 | * (input_derivatives_plus_one * theta.pow(2) \ 146 | + 2 * input_delta * theta_one_minus_theta \ 147 | + input_derivatives * (1 - theta).pow(2)) 148 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 149 | return outputs, logabsdet 150 | --------------------------------------------------------------------------------
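As a closing note on src/utils.py: the rational-quadratic spline is invertible inside the tail bound, and the log-determinants of the forward and inverse passes should cancel. The snippet below is a quick numerical sanity check of `unconstrained_RQS`, a sketch that assumes the repository root is importable; it passes raw unnormalized spline parameters, which `RQS` normalizes internally via softmax/softplus.

```python
import torch
from src.utils import unconstrained_RQS

torch.manual_seed(0)
n, K, B = 64, 5, 3                           # n points, K bins, tail bound B (defaults of NSF_AR/NSF_CL)
x = torch.empty(n).uniform_(-B, B)           # keep every point inside the spline interval
W = torch.randn(n, K)                        # unnormalized bin widths
H = torch.randn(n, K)                        # unnormalized bin heights
D = torch.randn(n, K - 1)                    # unnormalized knot derivatives

z, ld_fwd = unconstrained_RQS(x, W, H, D, inverse=False, tail_bound=B)
x_rec, ld_inv = unconstrained_RQS(z, W, H, D, inverse=True, tail_bound=B)

print((x_rec - x).abs().max().item())        # ~0 up to float32 error: the inverse recovers x
print((ld_fwd + ld_inv).abs().max().item())  # ~0: forward and inverse log-determinants cancel
```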