├── .gitignore
├── README.md
├── assets
│   ├── cubic_d.svg
│   ├── density.png
│   ├── flow_16.png
│   ├── flow_2.png
│   ├── flow_4.png
│   ├── linear_d.svg
│   └── simple_d.svg
├── densities.py
├── flow.py
├── losses.py
├── requirements.txt
├── run_experiment.py
├── utils.py
└── visualization.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__

env/
experiments/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Inference with Normalizing Flows

Reimplementation of Variational Inference with Normalizing Flows (https://arxiv.org/abs/1505.05770)

The idea is to approximate a complex multimodal probability density with a simple probability density followed by a sequence of invertible nonlinear transforms. Inference in such a model requires computing multiple Jacobian determinants, which can be computationally expensive. The authors propose a specific form of the transformation that reduces the cost of computing the Jacobians from approximately ![O(d^3)](/assets/cubic_d.svg) to ![O(d)](/assets/linear_d.svg), where ![d](/assets/simple_d.svg) is the dimensionality of the data.

**NOTE**: Currently I only provide an implementation for the simple case, where the true density can be expressed in closed form (up to a normalizing constant), so it's possible to explicitly minimize the KL divergence between the true density and the density represented by a normalizing flow. Implementing the most general case, a normalizing flow that learns from raw data, is problematic for the transformation described in the paper, since its inverse cannot be expressed in closed form. Currently I'm working on another kind of normalizing flow called [Glow](https://arxiv.org/abs/1807.03039), where all the transformations have closed-form inverse functions, and I'm planning to release it soon. Stay tuned!

I got the following results:

[Figures: the true density (assets/density.png) and results for flow lengths 2, 4 and 16 (assets/flow_2.png, assets/flow_4.png, assets/flow_16.png)]

As can be seen, the approximation quality improves as the flow length increases.

### Reproducing my results

To reproduce my results, you will need to install [PyTorch](http://pytorch.org/).

Then install the other dependencies from ```requirements.txt```. If you are using ```pip```, simply run ```pip install -r requirements.txt```.

After you have installed the dependencies, run ```python run_experiment.py``` and collect the results in the ```experiments``` folder.
--------------------------------------------------------------------------------
/assets/cubic_d.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/density.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/density.png
--------------------------------------------------------------------------------
/assets/flow_16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/flow_16.png
--------------------------------------------------------------------------------
/assets/flow_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/flow_2.png
--------------------------------------------------------------------------------
/assets/flow_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/flow_4.png
--------------------------------------------------------------------------------
/assets/linear_d.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/simple_d.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/densities.py:
--------------------------------------------------------------------------------
import torch


def p_z(z):
    """Unnormalized target density p(z) = exp(-U(z)) for a batch z of shape (N, 2)."""

    z1, z2 = torch.chunk(z, chunks=2, dim=1)
    norm = torch.sqrt(z1 ** 2 + z2 ** 2)

    # Two Gaussian bumps in z1, centered at +2 and -2.
    exp1 = torch.exp(-0.5 * ((z1 - 2) / 0.8) ** 2)
    exp2 = torch.exp(-0.5 * ((z1 + 2) / 0.8) ** 2)
    # Energy U(z): a ring of radius 4 around the origin, reweighted by the bumps.
    u = 0.5 * ((norm - 4) / 0.4) ** 2 - torch.log(exp1 + exp2)

    return torch.exp(-u)
--------------------------------------------------------------------------------
/flow.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from utils import safe_log


class NormalizingFlow(nn.Module):
    """A stack of planar flows that also returns log|det J| for every step."""

    def __init__(self, dim, flow_length):
        super().__init__()

        self.transforms = nn.Sequential(*(
            PlanarFlow(dim) for _ in range(flow_length)
        ))

        # Each helper shares its parameters with the corresponding transform.
        self.log_jacobians = nn.Sequential(*(
            PlanarFlowLogDetJacobian(t) for t in self.transforms
        ))

    def forward(self, z):

        log_jacobians = []

        for transform, log_jacobian in zip(self.transforms, self.log_jacobians):
            # log|det J| is evaluated at the input of each transform.
            log_jacobians.append(log_jacobian(z))
            z = transform(z)

        zk = z

        return zk, log_jacobians


class PlanarFlow(nn.Module):
    """Planar transform f(z) = z + u * tanh(w^T z + b)."""

    def __init__(self, dim):
        super().__init__()

        self.weight = nn.Parameter(torch.Tensor(1, dim))  # w
        self.bias = nn.Parameter(torch.Tensor(1))         # b
        self.scale = nn.Parameter(torch.Tensor(1, dim))   # u
        self.tanh = nn.Tanh()

        self.reset_parameters()

    def reset_parameters(self):

        self.weight.data.uniform_(-0.01, 0.01)
        self.scale.data.uniform_(-0.01, 0.01)
        self.bias.data.uniform_(-0.01, 0.01)

    def forward(self, z):

        activation = F.linear(z, self.weight, self.bias)  # shape (N, 1)
        return z + self.scale * self.tanh(activation)


class PlanarFlowLogDetJacobian(nn.Module):
    """A helper class to compute log|det J| of the planar flow transformation.
    By the matrix determinant lemma, det J = det(I + u psi(z)^T) = 1 + u^T psi(z),
    which costs O(d) instead of O(d^3)."""

    def __init__(self, affine):
        super().__init__()

        self.weight = affine.weight
        self.bias = affine.bias
        self.scale = affine.scale
        self.tanh = affine.tanh

    def forward(self, z):

        activation = F.linear(z, self.weight, self.bias)
        # psi(z) = h'(w^T z + b) * w, with h' = 1 - tanh^2 for h = tanh
        psi = (1 - self.tanh(activation) ** 2) * self.weight
        det_grad = 1 + torch.mm(psi, self.scale.t())  # shape (N, 1)
        return safe_log(det_grad.abs())
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

from utils import safe_log


class FreeEnergyBound(nn.Module):
    """Monte Carlo estimate of the free energy, up to an additive constant:
    the batch mean of -sum_k log|det J_k| - log p(z_K).
    The E_q0[log q0(z0)] term is dropped because the base density
    q0 = N(0, I) does not depend on the flow parameters."""

    def __init__(self, density):
        super().__init__()

        self.density = density

    def forward(self, zk, log_jacobians):

        sum_of_log_jacobians = sum(log_jacobians)
        return (-sum_of_log_jacobians - safe_log(self.density(zk))).mean()
--------------------------------------------------------------------------------
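The README's claim that the planar transformation reduces the Jacobian cost from cubic to linear in d follows from the matrix determinant lemma. The Jacobian of f(z) = z + u h(w^T z + b) is I + u psi(z)^T, whose determinant is 1 + u^T psi(z), and this is the quantity `PlanarFlowLogDetJacobian` in flow.py computes. The following minimal NumPy sketch is not part of the repository, and the function names `planar_logdet_fast` and `planar_logdet_slow` are hypothetical. It checks the O(d) formula against a brute-force determinant of the explicit d x d Jacobian, assuming h = tanh and 1-D vectors w and u in place of the (1, dim) tensors used in flow.py.

```python
import numpy as np


def planar_logdet_fast(z, w, u, b):
    """O(d) per sample: log|1 + u^T psi(z)| with psi(z) = (1 - tanh^2(w^T z + b)) * w."""
    h_prime = 1.0 - np.tanh(z @ w + b) ** 2         # shape (N,)
    return np.log(np.abs(1.0 + h_prime * (w @ u)))  # u^T psi(z) = h' * (w . u)


def planar_logdet_slow(z, w, u, b):
    """O(d^3) per sample: build the full Jacobian I + u psi(z)^T and take its determinant."""
    h_prime = 1.0 - np.tanh(z @ w + b) ** 2
    d = z.shape[1]
    logdets = []
    for hp in h_prime:
        jacobian = np.eye(d) + np.outer(u, hp * w)  # J_ij = delta_ij + u_i * h' * w_j
        logdets.append(np.log(np.abs(np.linalg.det(jacobian))))
    return np.array(logdets)


rng = np.random.RandomState(0)
d = 5
z = rng.randn(8, d)                  # a batch of 8 points in d dimensions
w, u, b = rng.randn(d), 0.5 * rng.randn(d), 0.1
fast = planar_logdet_fast(z, w, u, b)
slow = planar_logdet_slow(z, w, u, b)
print(np.allclose(fast, slow))
```

The slow version is only useful as a reference. For large d, the determinant of a dense Jacobian dominates the cost, and the rank-one structure of the planar map avoids it entirely.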
/requirements.txt:
--------------------------------------------------------------------------------
cycler==0.10.0
kiwisolver==1.0.1
matplotlib==2.2.2
numpy==1.14.2
Pillow==5.0.0
pyparsing==2.2.0
python-dateutil==2.7.0
pytz==2018.3
PyYAML==3.12
six==1.11.0

git+https://github.com/ex4sperans/mag
--------------------------------------------------------------------------------
/run_experiment.py:
--------------------------------------------------------------------------------
import argparse

import torch
from torch import optim
from mag.experiment import Experiment

from visualization import plot_density, scatter_points
from utils import random_normal_samples
from flow import NormalizingFlow
from losses import FreeEnergyBound
from densities import p_z


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
    "--log_interval", type=int, default=300,
    help="How frequently to print the training stats."
)
parser.add_argument(
    "--plot_interval", type=int, default=300,
    help="How frequently to plot samples from the current distribution."
)
parser.add_argument(
    "--plot_points", type=int, default=1000,
    help="How many points to generate for one plot."
)

args = parser.parse_args()

torch.manual_seed(42)


with Experiment({
    "batch_size": 40,
    "iterations": 10000,
    "initial_lr": 0.01,
    "lr_decay": 0.999,
    "flow_length": 16,
    "name": "planar"
}) as experiment:

    config = experiment.config
    experiment.register_directory("samples")
    experiment.register_directory("distributions")

    flow = NormalizingFlow(dim=2, flow_length=config.flow_length)
    bound = FreeEnergyBound(density=p_z)
    optimizer = optim.RMSprop(flow.parameters(), lr=config.initial_lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, config.lr_decay)

    plot_density(p_z, directory=experiment.distributions)

    def should_log(iteration):
        return iteration % args.log_interval == 0

    def should_plot(iteration):
        return iteration % args.plot_interval == 0

    for iteration in range(1, config.iterations + 1):

        # Draw a fresh batch from the base density q0 = N(0, I).
        samples = random_normal_samples(config.batch_size)
        zk, log_jacobians = flow(samples)

        optimizer.zero_grad()
        loss = bound(zk, log_jacobians)
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1, scheduler.step() is called after optimizer.step().
        scheduler.step()

        if should_log(iteration):
            print("Loss on iteration {}: {}".format(iteration, loss.item()))

        if should_plot(iteration):
            # No gradients are needed just to draw samples for the plot.
            with torch.no_grad():
                samples = random_normal_samples(args.plot_points)
                zk, _ = flow(samples)
            scatter_points(
                zk.numpy(),
                directory=experiment.samples,
                iteration=iteration,
                flow_length=config.flow_length
            )
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch


def safe_log(z):
    # A small epsilon avoids log(0) = -inf.
    return torch.log(z + 1e-7)


def random_normal_samples(n, dim=2):
    # n samples from the standard normal base density N(0, I).
    return torch.zeros(n, dim).normal_(mean=0, std=1)
--------------------------------------------------------------------------------
/visualization.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch
from matplotlib import pyplot as plt


def scatter_points(points, directory, iteration, flow_length):

    X_LIMS = (-7, 7)
    Y_LIMS = (-7, 7)

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    ax.scatter(points[:, 0], points[:, 1], alpha=0.7, s=25)
    ax.set_xlim(*X_LIMS)
    ax.set_ylim(*Y_LIMS)
    ax.set_title(
        "Flow length: {}\n Samples on iteration #{}"
        .format(flow_length, iteration)
    )

    fig.savefig(os.path.join(directory, "flow_result_{}.png".format(iteration)))
    plt.close(fig)


def plot_density(density, directory):

    X_LIMS = (-7, 7)
    Y_LIMS = (-7, 7)

    x1 = np.linspace(*X_LIMS, 300)
    x2 = np.linspace(*Y_LIMS, 300)
    x1, x2 = np.meshgrid(x1, x2)
    shape = x1.shape
    x1 = x1.ravel()
    x2 = x2.ravel()

    z = np.c_[x1, x2]
    z = torch.FloatTensor(z)

    with torch.no_grad():
        density_values = density(z).numpy().reshape(shape)

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    # origin="lower" draws row 0 (x2 = -7) at the bottom, consistent with extent.
    ax.imshow(density_values, extent=(*X_LIMS, *Y_LIMS), origin="lower", cmap="summer")
    ax.set_title("True density")

    fig.savefig(os.path.join(directory, "density.png"))
    plt.close(fig)
--------------------------------------------------------------------------------
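The paper also points out that a planar map with h = tanh is invertible only if w^T u >= -1. `PlanarFlow` in flow.py does not enforce this. Its initialization keeps w and u near zero, but nothing stops training from violating the condition, and `PlanarFlowLogDetJacobian` takes `abs()` of the determinant, so the loss stays finite either way. The paper's appendix suggests reparameterizing u as u_hat = u + [m(w^T u) - w^T u] w / ||w||^2 with m(x) = -1 + log(1 + e^x), which guarantees w^T u_hat > -1.

Below is a minimal NumPy sketch of that reparameterization. It is not part of the repository, and the name `constrain_u` is hypothetical. In the PyTorch code, one would presumably compute u_hat from `self.weight` and `self.scale` and use it in place of `self.scale` in both `PlanarFlow.forward` and `PlanarFlowLogDetJacobian.forward`.

```python
import numpy as np


def softplus(x):
    """log(1 + e^x), computed without overflow for large x."""
    return np.logaddexp(0.0, x)


def constrain_u(w, u):
    """Map an unconstrained u to u_hat with w . u_hat > -1 (planar-flow invertibility)."""
    wu = w @ u
    m = -1.0 + softplus(wu)  # m(x) = -1 + log(1 + e^x), always greater than -1
    # Shift u along w so that its projection onto w becomes m(w . u);
    # the component of u orthogonal to w is left unchanged.
    return u + (m - wu) * w / (w @ w)


w = np.array([1.0, 2.0])
u = np.array([-3.0, 0.5])  # w . u = -2, which violates w . u >= -1
u_hat = constrain_u(w, u)
print(w @ u_hat)           # equals -1 + log(1 + e^-2), so greater than -1
```

Because h' = 1 - tanh^2 lies in (0, 1], the constraint also keeps 1 + h' w^T u_hat strictly positive. The `abs()` in the log-determinant would then no longer change anything.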