├── .gitignore
├── README.md
├── assets
│   ├── cubic_d.svg
│   ├── density.png
│   ├── flow_16.png
│   ├── flow_2.png
│   ├── flow_4.png
│   ├── linear_d.svg
│   └── simple_d.svg
├── densities.py
├── flow.py
├── losses.py
├── requirements.txt
├── run_experiment.py
├── utils.py
└── visualization.py
/.gitignore:
--------------------------------------------------------------------------------
__pycache__

env/
experiments/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Inference with Normalizing Flows

Reimplementation of Variational Inference with Normalizing Flows (https://arxiv.org/abs/1505.05770)

The idea is to approximate a complex multimodal probability density with a simple probability density followed by a sequence of invertible nonlinear transforms. Inference in such a model requires computing multiple Jacobian determinants, which can be computationally expensive. The authors propose a specific form of the transformation that reduces the cost of computing the Jacobians from approximately O(D^3) to O(D), where D is the dimensionality of the data.
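
For reference, the two identities from the paper that make this work: the change-of-variables formula for a flow of length K, and the closed-form Jacobian determinant of the planar transform (h denotes a smooth nonlinearity, tanh in this implementation):

```latex
% Log-density after a flow z_K = f_K(...f_1(z_0)), with z_0 drawn from a simple base density q_0:
\ln q_K(z_K) = \ln q_0(z_0) - \sum_{k=1}^{K} \ln \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right|

% Planar transform with parameters u, w in R^D and b in R:
f(z) = z + u \, h(w^\top z + b)

% Its Jacobian determinant has a closed form, computable in O(D):
\left| \det \frac{\partial f}{\partial z} \right| = \left| 1 + u^\top \psi(z) \right|,
\qquad \psi(z) = h'(w^\top z + b) \, w
```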

**NOTE**: Currently I only provide an implementation for the simple case, where the true density can be expressed in closed form, so it is possible to explicitly minimize the KL-divergence between the true density and the density represented by the normalizing flow. Implementing the most general case of a normalizing flow, one capable of learning from raw data, is problematic for the transformation described in the paper, since its inverse cannot be expressed in closed form. Currently I'm working on another kind of normalizing flow called [Glow](https://arxiv.org/abs/1807.03039), where all the transformations have closed-form inverses, and I'm planning to release it soon. Stay tuned!

I got the following results:

![True density](assets/density.png)

![Flow of length 2](assets/flow_2.png)

![Flow of length 4](assets/flow_4.png)

![Flow of length 16](assets/flow_16.png)

As can be seen, the approximation quality indeed improves as the flow length increases.

### Reproducing my results

To reproduce my results, you will need to install [pytorch](http://pytorch.org/).

Then install the remaining dependencies from ```requirements.txt```. If you are using ```pip```, simply run ```pip install -r requirements.txt```.

After you have installed the dependencies, run ```python run_experiment.py``` and collect the results in the ```experiments``` folder.
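
The script also accepts a few optional flags (see the argparse section of ```run_experiment.py```); for example, to log and plot every 500 iterations with 2000 points per plot:

```
python run_experiment.py --log_interval 500 --plot_interval 500 --plot_points 2000
```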
--------------------------------------------------------------------------------
/assets/cubic_d.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/density.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/density.png
--------------------------------------------------------------------------------
/assets/flow_16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/flow_16.png
--------------------------------------------------------------------------------
/assets/flow_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/flow_2.png
--------------------------------------------------------------------------------
/assets/flow_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ex4sperans/variational-inference-with-normalizing-flows/922b569f851e02fa74700cd0754fe2ef5c1f3180/assets/flow_4.png
--------------------------------------------------------------------------------
/assets/linear_d.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/simple_d.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/densities.py:
--------------------------------------------------------------------------------
import torch


def p_z(z):
    """Unnormalized target density: a ring of radius 4 with two Gaussian
    bumps along the first coordinate (a variant of the synthetic energy
    functions from the paper)."""

    z1, z2 = torch.chunk(z, chunks=2, dim=1)
    norm = torch.sqrt(z1 ** 2 + z2 ** 2)

    # Energy U(z): a quadratic ring term plus a -log of two bumps at z1 = ±2.
    exp1 = torch.exp(-0.5 * ((z1 - 2) / 0.8) ** 2)
    exp2 = torch.exp(-0.5 * ((z1 + 2) / 0.8) ** 2)
    u = 0.5 * ((norm - 4) / 0.4) ** 2 - torch.log(exp1 + exp2)

    return torch.exp(-u)
--------------------------------------------------------------------------------
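
For intuition, here is a quick check (a hypothetical snippet, not part of the repo): a point on the ring should receive a much higher unnormalized density than a point at the origin.

```python
import torch

from densities import p_z

# One point on the ring (norm 4), one point at the origin.
z = torch.tensor([[4.0, 0.0], [0.0, 0.0]])
print(p_z(z))  # the first value is on the order of 1e-2; the second is vanishingly small
```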
/flow.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from utils import safe_log


class NormalizingFlow(nn.Module):
    """A sequence of planar flows, paired with modules that compute the
    log-det-Jacobian of each transform."""

    def __init__(self, dim, flow_length):
        super().__init__()

        self.transforms = nn.Sequential(*(
            PlanarFlow(dim) for _ in range(flow_length)
        ))

        self.log_jacobians = nn.Sequential(*(
            PlanarFlowLogDetJacobian(t) for t in self.transforms
        ))

    def forward(self, z):

        log_jacobians = []

        for transform, log_jacobian in zip(self.transforms, self.log_jacobians):
            # The log-det-Jacobian is evaluated at the *input* of each
            # transform, so compute it before applying the transform itself.
            log_jacobians.append(log_jacobian(z))
            z = transform(z)

        zk = z

        return zk, log_jacobians


class PlanarFlow(nn.Module):
    """A single planar transform: f(z) = z + u * tanh(w^T z + b)."""

    def __init__(self, dim):
        super().__init__()

        self.weight = nn.Parameter(torch.Tensor(1, dim))
        self.bias = nn.Parameter(torch.Tensor(1))
        self.scale = nn.Parameter(torch.Tensor(1, dim))
        self.tanh = nn.Tanh()

        self.reset_parameters()

    def reset_parameters(self):

        self.weight.data.uniform_(-0.01, 0.01)
        self.scale.data.uniform_(-0.01, 0.01)
        self.bias.data.uniform_(-0.01, 0.01)

    def forward(self, z):

        activation = F.linear(z, self.weight, self.bias)
        return z + self.scale * self.tanh(activation)


class PlanarFlowLogDetJacobian(nn.Module):
    """A helper class to compute the log-determinant of the Jacobian of
    the planar flow transformation."""

    def __init__(self, affine):
        super().__init__()

        self.weight = affine.weight
        self.bias = affine.bias
        self.scale = affine.scale
        self.tanh = affine.tanh

    def forward(self, z):

        activation = F.linear(z, self.weight, self.bias)
        # psi(z) = h'(w^T z + b) * w, with h = tanh, so h' = 1 - tanh^2.
        psi = (1 - self.tanh(activation) ** 2) * self.weight
        # |det(df/dz)| = |1 + u^T psi(z)|, an O(D) computation.
        det_grad = 1 + torch.mm(psi, self.scale.t())
        return safe_log(det_grad.abs())
--------------------------------------------------------------------------------
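
As a quick sanity check of the interfaces above (a hypothetical snippet, not part of the repo), the flow maps a batch of base samples to transformed samples plus one log-det-Jacobian term per transform:

```python
import torch

from flow import NormalizingFlow

flow = NormalizingFlow(dim=2, flow_length=4)
z0 = torch.randn(8, 2)            # a batch of 8 samples from the base density
zk, log_jacobians = flow(z0)

assert zk.shape == (8, 2)         # transformed samples keep their shape
assert len(log_jacobians) == 4    # one log|det J| tensor per planar transform
assert log_jacobians[0].shape == (8, 1)
```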
/losses.py:
--------------------------------------------------------------------------------
from torch import nn

from utils import safe_log


class FreeEnergyBound(nn.Module):
    """The flow-based free energy from the paper, up to the constant
    entropy of the base density: E_q0[-sum_k log|det J_k| - log p(z_K)]."""

    def __init__(self, density):
        super().__init__()

        self.density = density

    def forward(self, zk, log_jacobians):

        sum_of_log_jacobians = sum(log_jacobians)
        return (-sum_of_log_jacobians - safe_log(self.density(zk))).mean()
--------------------------------------------------------------------------------
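
For context, the loss above is the flow-based free energy from the paper with the base-density entropy term dropped, since that term is constant with respect to the flow parameters:

```latex
\mathcal{F}
= \underbrace{\mathbb{E}_{q_0}\!\left[\ln q_0(z_0)\right]}_{\text{constant, omitted in the code}}
- \mathbb{E}_{q_0}\!\left[\sum_{k=1}^{K} \ln \left| 1 + u_k^\top \psi_k(z_{k-1}) \right|\right]
- \mathbb{E}_{q_0}\!\left[\ln p(z_K)\right]
```

Minimizing the last two terms therefore minimizes KL(q_K || p) up to an additive constant (another constant arises because p is only known up to normalization).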
/requirements.txt:
--------------------------------------------------------------------------------
cycler==0.10.0
kiwisolver==1.0.1
matplotlib==2.2.2
numpy==1.14.2
Pillow==5.0.0
pyparsing==2.2.0
python-dateutil==2.7.0
pytz==2018.3
PyYAML==3.12
six==1.11.0

git+https://github.com/ex4sperans/mag
--------------------------------------------------------------------------------
/run_experiment.py:
--------------------------------------------------------------------------------
import argparse

import torch
from torch import optim
from mag.experiment import Experiment

from visualization import plot_density, scatter_points
from utils import random_normal_samples
from flow import NormalizingFlow
from losses import FreeEnergyBound
from densities import p_z


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
    "--log_interval", type=int, default=300,
    help="How frequently to print the training stats."
)
parser.add_argument(
    "--plot_interval", type=int, default=300,
    help="How frequently to plot samples from the current distribution."
)
parser.add_argument(
    "--plot_points", type=int, default=1000,
    help="How many points to generate for one plot."
)

args = parser.parse_args()

torch.manual_seed(42)


with Experiment({
    "batch_size": 40,
    "iterations": 10000,
    "initial_lr": 0.01,
    "lr_decay": 0.999,
    "flow_length": 16,
    "name": "planar"
}) as experiment:

    config = experiment.config
    experiment.register_directory("samples")
    experiment.register_directory("distributions")

    flow = NormalizingFlow(dim=2, flow_length=config.flow_length)
    bound = FreeEnergyBound(density=p_z)
    optimizer = optim.RMSprop(flow.parameters(), lr=config.initial_lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, config.lr_decay)

    plot_density(p_z, directory=experiment.distributions)

    def should_log(iteration):
        return iteration % args.log_interval == 0

    def should_plot(iteration):
        return iteration % args.plot_interval == 0

    for iteration in range(1, config.iterations + 1):

        samples = random_normal_samples(config.batch_size)
        zk, log_jacobians = flow(samples)

        optimizer.zero_grad()
        loss = bound(zk, log_jacobians)
        loss.backward()
        optimizer.step()
        # Decay the learning rate after the parameter update, as required
        # by recent PyTorch versions.
        scheduler.step()

        if should_log(iteration):
            print("Loss on iteration {}: {}".format(iteration, loss.item()))

        if should_plot(iteration):
            samples = random_normal_samples(args.plot_points)
            zk, det_grads = flow(samples)
            scatter_points(
                zk.detach().numpy(),
                directory=experiment.samples,
                iteration=iteration,
                flow_length=config.flow_length
            )
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch


def safe_log(z):
    # Add a small epsilon to avoid log(0) for near-zero densities.
    return torch.log(z + 1e-7)


def random_normal_samples(n, dim=2):
    # Draw n samples from the standard normal base density q_0 = N(0, I).
    return torch.randn(n, dim)
--------------------------------------------------------------------------------
/visualization.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch
from matplotlib import pyplot as plt


def scatter_points(points, directory, iteration, flow_length):

    X_LIMS = (-7, 7)
    Y_LIMS = (-7, 7)

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    ax.scatter(points[:, 0], points[:, 1], alpha=0.7, s=25)
    ax.set_xlim(*X_LIMS)
    ax.set_ylim(*Y_LIMS)
    ax.set_title(
        "Flow length: {}\nSamples on iteration #{}"
        .format(flow_length, iteration)
    )

    fig.savefig(os.path.join(directory, "flow_result_{}.png".format(iteration)))
    plt.close()


def plot_density(density, directory):

    X_LIMS = (-7, 7)
    Y_LIMS = (-7, 7)

    x1 = np.linspace(*X_LIMS, 300)
    x2 = np.linspace(*Y_LIMS, 300)
    x1, x2 = np.meshgrid(x1, x2)
    shape = x1.shape
    x1 = x1.ravel()
    x2 = x2.ravel()

    z = np.c_[x1, x2]
    z = torch.FloatTensor(z)

    density_values = density(z).numpy().reshape(shape)

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    # origin="lower" keeps the orientation consistent with the meshgrid,
    # where the second coordinate increases with the row index.
    ax.imshow(density_values, extent=(*X_LIMS, *Y_LIMS), origin="lower", cmap="summer")
    ax.set_title("True density")

    fig.savefig(os.path.join(directory, "density.png"))
    plt.close()
--------------------------------------------------------------------------------