├── imgs
│   ├── acc.png
│   ├── mse.png
│   ├── fce-maf.png
│   ├── value.png
│   └── fce-glow.png
├── .gitattributes
├── ebm.py
├── LICENSE
├── train.py
├── flows
│   ├── glow.py
│   └── maf.py
├── README.md
└── util.py
/imgs/acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/acc.png
--------------------------------------------------------------------------------
/imgs/mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/mse.png
--------------------------------------------------------------------------------
/imgs/fce-maf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/fce-maf.png
--------------------------------------------------------------------------------
/imgs/value.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/value.png
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/imgs/fce-glow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/fce-glow.png
--------------------------------------------------------------------------------
/ebm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class EBM(nn.Module):
6 | def __init__(self):
7 | super(EBM, self).__init__()
8 |         # The log normalizing constant log Z(θ), treated as a trainable parameter
9 | self.c = nn.Parameter(torch.tensor([1.], requires_grad=True))
10 |
11 | self.f = nn.Sequential(
12 | nn.Linear(2, 128),
13 | nn.LeakyReLU(0.2, inplace=True),
14 | nn.Linear(128, 128),
15 | nn.LeakyReLU(0.2, inplace=True),
16 | nn.Linear(128, 1),
17 | )
18 |
19 | def forward(self, x):
20 | log_prob = - self.f(x) - self.c
21 | return log_prob
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 franli
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Flow Contrastive Estimation (FCE) on 2D dataset.
3 | https://arxiv.org/abs/1912.00589
4 | """
5 | import os
6 | import math
7 | import argparse
8 | import torch
9 | from ebm import EBM
10 | import util
11 |
12 | import wandb
13 |
14 |
15 | parser = argparse.ArgumentParser(description='Flow Contrastive Estimation of Energy Based Model')
16 | parser.add_argument('--seed', default=42, type=int, help='seed')
17 | parser.add_argument('--epoch', default=100, type=int, help='number of training epochs')
18 | parser.add_argument('--flow', default='glow', type=str, help='Flow model to use')
19 | parser.add_argument('--threshold', default=0.6, type=float, help='threshold for alternate training')
20 | parser.add_argument('--batch', default=1000, type=int, help='batch size')
21 | parser.add_argument('--dataset', default='8gaussians', type=str, choices=['8gaussians', 'spiral', '2spirals', 'checkerboard', 'rings', 'pinwheel'], help='2D dataset to use')
22 | parser.add_argument('--samples', default=500000, type=int, help='number of 2D samples for training')
23 | parser.add_argument('--lr_ebm', default=1e-3, type=float, help='learning rate for EBM')
24 | parser.add_argument('--lr_flow', default=7e-4, type=float, help='learning rate for Flow')
25 | parser.add_argument('--b1', type=float, default=0.9, help='adam: decay of first order momentum of gradient')
26 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
27 | args = parser.parse_args()
28 |
29 | wandb.init(project='FCE-2d')
30 | wandb.config.update(args)
31 |
32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33 | torch.manual_seed(args.seed)
34 | # ------------------------------
35 | # I. MODELS
36 | # ------------------------------
37 | energy = EBM().to(device)
38 | if args.flow == 'glow':
39 | from flows.glow import Glow
40 | flow = Glow(width=64, depth=5, n_levels=1, data_dim=2).to(device)
41 | elif args.flow == 'maf':
42 | from flows.maf import MAF
43 | flow = MAF(n_blocks=5, input_size=2, hidden_size=100, n_hidden=1).to(device)
44 | # ------------------------------
45 | # II. OPTIMIZERS
46 | # ------------------------------
47 | optim_energy = torch.optim.Adam(energy.parameters(), lr=args.lr_ebm, betas=(args.b1, args.b2))
48 | optim_flow = torch.optim.Adam(flow.parameters(), lr=args.lr_flow, betas=(args.b1, args.b2))
49 | # ------------------------------
50 | # III. DATA LOADER
51 | # ------------------------------
52 | dataset, dataloader = util.get_data(args)
53 | dataset = dataset.to(device)
54 | # ------------------------------
55 | # IV. TRAINING
56 | # ------------------------------
57 | wandb.watch(energy)
58 | wandb.watch(flow)
59 |
60 | def main(args):
61 | train_energy = True
62 | for epoch in range(args.epoch):
63 | for i, x in enumerate(dataloader):
64 | x = x.to(device)
65 | # -----------------------------
66 | # Generate noise
67 | # -----------------------------
68 | z = flow.base_dist.sample((args.batch,))
69 | # -----------------------------
70 | # Train Energy Model
71 | # -----------------------------
72 | if train_energy:
73 | optim_energy.zero_grad()
74 | loss_energy, acc = util.value(energy, flow, x, z, maximize=True)
75 | loss_energy.backward()
76 | optim_energy.step()
77 | # -----------------------------
78 | # Train Flow Model
79 | # -----------------------------
80 | else:
81 | optim_flow.zero_grad()
82 | loss_flow, acc = util.value(energy, flow, x, z, maximize=False)
83 | loss_flow.backward()
84 | optim_flow.step()
85 |
86 | wandb.log({'epoch': epoch,
87 | 'value': loss_energy.item() if train_energy else -loss_flow.item(),
88 | 'acc': acc,
89 | 'mse': util.mse(energy, util.MixedGaussian(device=device), dataset), # comment out if not using 8gaussians
90 | })
91 |
92 |
93 | if acc > args.threshold:
94 | train_energy = False
95 | else:
96 | train_energy = True
97 |
98 |
99 | # Save checkpoint
100 | print('Saving models...')
101 | state = {
102 | 'energy': energy.state_dict(),
103 | 'flow': flow.state_dict(),
104 | 'value': loss_energy,
105 | 'epoch': epoch,
106 | }
107 | os.makedirs('ckpts', exist_ok=True)
108 | ckpts = 'ckpts/fce-{}-2d-{}.pth.tar'.format(args.flow, args.dataset)
109 | torch.save(state, ckpts)
110 |
111 | # visualization
112 | # util.plot(dataset, energy, flow, epoch, device)
113 |
114 |
115 |
116 |
117 |
118 | if __name__ == '__main__':
119 | print(args)
120 | main(args)
121 |
--------------------------------------------------------------------------------
/flows/glow.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Glow[1] model in PyTorch for 2D dataset. Adapted from
3 | https://github.com/kamenbliznashki/normalizing_flows/blob/master/glow.py
4 |
5 | [1] https://arxiv.org/abs/1807.03039
6 | """
7 | import os
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.distributions as D
13 |
14 | # --------------------
15 | # Model component layers
16 | # --------------------
17 | class Actnorm(nn.Module):
18 | def __init__(self, param_dim=(1,2)):
19 | super().__init__()
20 | self.scale = nn.Parameter(torch.ones(param_dim))
21 | self.bias = nn.Parameter(torch.zeros(param_dim))
22 | self.register_buffer('initialized', torch.tensor(0).byte())
23 |
24 | def forward(self, x):
25 | if not self.initialized:
26 | # x.shape = (B, W)
27 |             self.bias.data.copy_(x.mean(0, keepdim=True))
28 |             self.scale.data.copy_(x.std(0, unbiased=False, keepdim=True) + 1e-6)
29 | self.initialized += 1
30 |
31 | z = (x - self.bias) / self.scale
32 | logdet = - self.scale.abs().log().sum()
33 | return z, logdet
34 |
35 | def inverse(self, z):
36 | x = z * self.scale + self.bias
37 | logdet = self.scale.abs().log().sum()
38 | return x, logdet
39 |
40 |
41 | class Invertible1x1Conv(nn.Module):
42 | def __init__(self, dim=2):
43 | super().__init__()
44 |
45 | w = torch.randn(dim, dim)
46 | w = torch.linalg.qr(w)[0] # W^{-1} = W^T (only at initialization)
47 | self.w = nn.Parameter(w)
48 |
49 | def forward(self, x):
50 | logdet = torch.slogdet(self.w)[-1]
51 | return x @ self.w.t(), logdet # (WX)^T = X^TW^T = y^T.
52 |
53 | def inverse(self, z):
54 | w_inv = self.w.t().inverse()
55 | logdet = - torch.slogdet(self.w)[-1]
56 | return z @ w_inv, logdet
57 |
58 |
59 | class AffineCoupling(nn.Module):
60 | def __init__(self, dim=2, width=128):
61 | super().__init__()
62 | self.fc1 = nn.Linear(dim//2, width, bias=True)
63 | self.actnorm1 = Actnorm(param_dim=(1, width))
64 | self.fc2 = nn.Linear(width, width, bias=True)
65 | self.actnorm2 = Actnorm(param_dim=(1, width))
66 | self.fc3 = nn.Linear(width, dim, bias=True)
67 | self.log_scale_factor = nn.Parameter(torch.zeros(1,2))
68 |
69 | self.fc3.weight.data.zero_()
70 | self.fc3.bias.data.zero_()
71 |
72 | def forward(self, x):
73 | x_a, x_b = x.chunk(2, dim=1) # x.shape = [batch, 2]
74 |
75 | h = F.relu(self.actnorm1(self.fc1(x_b))[0])
76 | h = F.relu(self.actnorm2(self.fc2(h))[0])
77 | h = self.fc3(h) * self.log_scale_factor.exp()
78 | t = h[:,0::2] # shift; take even dimension(s)
79 | s = h[:,1::2] # scale; take odd dimension(s)
80 | s = torch.sigmoid(s + 2.)
81 |
82 | z_a = s * x_a + t
83 | z_b = x_b
84 | z = torch.cat([z_a, z_b], dim=1) # z.shape = [batch, 2]
85 | logdet = s.log().sum(1)
86 |
87 | return z, logdet
88 |
89 | def inverse(self, z):
90 | z_a, z_b = z.chunk(2, dim=1)
91 |
92 | h = F.relu(self.actnorm1(self.fc1(z_b))[0])
93 | h = F.relu(self.actnorm2(self.fc2(h))[0])
94 | h = self.fc3(h) * self.log_scale_factor.exp()
95 | t = h[:,0::2] # shift; take even dimension(s)
96 | s = h[:,1::2] # scale; take odd dimension(s)
97 | s = torch.sigmoid(s + 2.)
98 |
99 | x_a = (z_a - t) / s
100 | x_b = z_b
101 | x = torch.cat([x_a, x_b], dim=1)
102 |
103 | logdet = - s.log().sum(1)
104 | return x, logdet
105 |
106 | # --------------------
107 | # Container layers
108 | # --------------------
109 | class FlowSequential(nn.Sequential):
110 | def __init__(self, *args, **kwargs):
111 | super().__init__(*args, **kwargs)
112 |
113 | def forward(self, x):
114 | slogdet = 0.
115 | for module in self:
116 | x, logdet = module(x)
117 | slogdet = slogdet + logdet
118 | return x, slogdet
119 |
120 | def inverse(self, z):
121 | slogdet = 0.
122 | for module in reversed(self):
123 | z, logdet = module.inverse(z)
124 | slogdet = slogdet + logdet
125 | return z, slogdet
126 |
127 |
128 | class FlowStep(FlowSequential):
129 | """ One step (Actnorm -> Invertible 1x1 conv -> Affine coupling) """
130 | def __init__(self, dim=2, width=128):
131 | super().__init__(
132 | Actnorm(param_dim=(1,dim)),
133 | Invertible1x1Conv(dim=dim),
134 | AffineCoupling(dim=dim, width=width))
135 |
136 |
137 | class FlowLevel(nn.Module):
138 |     """ One flow level: a stack of `depth` FlowSteps """
139 | def __init__(self, dim=2, width=128, depth=10):
140 | super().__init__()
141 | self.flowsteps = FlowSequential(*[FlowStep(dim, width) for _ in range(depth)]) # original: FlowStep(4*n_channels, width)
142 |
143 | def forward(self, x):
144 | z, logdet = self.flowsteps(x)
145 | return z, logdet
146 |
147 | def inverse(self, z):
148 | x, logdet = self.flowsteps.inverse(z)
149 | return x, logdet
150 |
151 |
152 | # --------------------
153 | # Model
154 | # --------------------
155 | class Glow(nn.Module):
156 | """ Glow multi-scale architecture with depth of flow K and number of levels L"""
157 | def __init__(self, width=128, depth=10, n_levels=1, data_dim=2):
158 | super().__init__()
159 |
160 | # (FlowStep x depth) x n_levels
161 | self.flowlevels = nn.ModuleList([FlowLevel(dim=data_dim, width=width, depth=depth) for i in range(n_levels)])
162 | self.flowstep = FlowSequential(*[FlowStep(dim=data_dim, width=width) for _ in range(depth)])
163 |
164 | # base distribution of the flow
165 | self.register_buffer('base_dist_mean', torch.zeros(2))
166 | self.register_buffer('base_dist_var', torch.eye(2))
167 |
168 | @property
169 | def base_dist(self):
170 | return D.MultivariateNormal(self.base_dist_mean, self.base_dist_var)
171 |
172 | def forward(self, x):
173 | slogdet = 0
174 | for m in self.flowlevels:
175 |             x, logdet = m(x)  # chain the output of each level into the next
176 |             slogdet = slogdet + logdet
177 |         z, logdet = self.flowstep(x)
178 | slogdet = slogdet + logdet
179 | return z, slogdet
180 |
181 | def inverse(self, z=None, batch_size=32, z_std=1.):
182 | if z is None:
183 | z = z_std * self.base_dist.sample((batch_size,))
184 | x, slogdet = self.flowstep.inverse(z)
185 | for m in reversed(self.flowlevels):
186 | x, logdet = m.inverse(x)
187 | slogdet = slogdet + logdet
188 |
189 | # get logq(x̃), where x̃ = f^{-1}(z)
190 | logq_gen = (self.base_dist.log_prob(z) - slogdet).unsqueeze(1)
191 |
192 | return x, logq_gen
193 |
194 | def log_prob(self, x):
195 | z, logdet = self.forward(x)
196 | log_prob = self.base_dist.log_prob(z) + logdet
197 | return log_prob.unsqueeze(1)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flow Contrastive Estimation (FCE)
2 |
3 | This is a PyTorch implementation of [Flow Contrastive Estimation](https://openaccess.thecvf.com/content_CVPR_2020/html/Gao_Flow_Contrastive_Estimation_of_Energy-Based_Models_CVPR_2020_paper.html) on 2D toy datasets.
4 |
5 | ## Introduction
6 |
7 | ### Direct Estimation of an Energy Model is Difficult
8 |
9 | Our problem is to estimate an energy-based model (EBM)
10 |
11 | $$p\_\theta(x) = \frac{\exp[-f\_\theta(x)]}{Z(\theta)}$$
12 |
13 | where
14 |
15 | $$Z(\theta) = \int\exp[-f\_\theta(x)]dx$$
16 |
17 | is the normalizing constant. The energy model specifies a probability distribution over the data space. The normalizing constant is generally intractable, since it requires integrating $\exp[-f\_\theta(x)]$ over the entire data space (or summing over exponentially many configurations in the discrete case).
18 |
19 | The energy-based model is implemented in [ebm.py](ebm.py).
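As a quick illustration of what that module computes, here is a minimal usage sketch (not part of the repository): the model returns an unnormalized log density $\log p\_\theta(x) = -f\_\theta(x) - c$, where the scalar parameter `c` plays the role of $\log Z(\theta)$ and is learned like any other weight.

```python
# Illustrative sketch, not part of the repo: evaluate the EBM's log density.
import torch
from ebm import EBM

energy = EBM()
x = torch.randn(4, 2)   # a batch of 2D points
log_p = energy(x)       # shape (4, 1): -f(x) - c
print(log_p.shape)
```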
20 |
21 | ### NCE: Teach EBM to Classify Data and Noise
22 |
23 | One approach to estimating an EBM is [Noise Contrastive Estimation (NCE)](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf). In NCE, the normalizing constant is treated as a trainable parameter, and the model parameters are estimated by training the EBM to classify data against noise. Let $p\_{\mathrm{data}}(x)$ denote the data distribution and $q(x)$ some noise distribution. Estimation then amounts to maximizing the following posterior log-likelihood of the classification:
24 |
25 | $$V(\theta) = \mathbb{E}\_{x\sim p\_{\text{data}}}\log\frac{p\_\theta(x)}{p\_\theta(x)+q(x)} + \mathbb{E}\_{\tilde{x}\sim q}\log\frac{q(\tilde{x})}{p\_\theta(\tilde{x}) + q(\tilde{x})}.$$
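For concreteness, here is a toy sketch of this objective with a fixed Gaussian noise distribution (illustrative only, not part of the repository); both terms are evaluated in log space with `logsumexp`, exactly as the `value` function in [util.py](util.py) does for FCE:

```python
# Toy NCE sketch (illustrative, not part of the repo): with a *fixed* noise
# distribution q, the EBM is trained purely as a data-vs-noise classifier.
import torch
import torch.distributions as D
from ebm import EBM

energy = EBM()
q = D.MultivariateNormal(torch.zeros(2), 4.0 * torch.eye(2))  # fixed Gaussian noise

x = torch.randn(128, 2)                 # stand-in for a batch of data points
x_noise = q.sample((128,))              # noise samples

logp_x, logq_x = energy(x), q.log_prob(x).unsqueeze(1)
logp_n, logq_n = energy(x_noise), q.log_prob(x_noise).unsqueeze(1)

# V(θ): both expectations computed in log space via logsumexp
term_data = logp_x - torch.logsumexp(torch.cat([logp_x, logq_x], dim=1), dim=1, keepdim=True)
term_noise = logq_n - torch.logsumexp(torch.cat([logp_n, logq_n], dim=1), dim=1, keepdim=True)
v = term_data.mean() + term_noise.mean()
(-v).backward()                         # ascend V with respect to θ
```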
26 |
27 | ### FCE: Replace Noise in NCE with Flow Model
28 |
29 | In Flow Contrastive Estimation (FCE), the noise $q(x)$ is replaced with a flow model $q\_\alpha(x)$, and the two models are trained jointly by alternately maximizing (over $\theta$) and minimizing (over $\alpha$) the posterior log-likelihood of the classification:
30 |
31 | $$V(\alpha,\theta) = \mathbb{E}\_{x\sim p\_{\text{data}}}\log\frac{p\_\theta(x)}{p\_\theta(x)+q\_\alpha(x)} + \mathbb{E}\_{\tilde{x}\sim q\_\alpha}\log\frac{q\_\alpha(\tilde{x})}{p\_\theta(\tilde{x}) + q\_\alpha(\tilde{x})}.$$
32 |
33 | This objective is implemented as the `value` function in file [util.py](util.py).
34 |
35 | When the classification accuracy is below a threshold, the energy model is trained. Otherwise, the flow model is trained.
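Concretely, the alternation in [train.py](train.py) boils down to the following per-batch logic (a condensed sketch; `energy`, `flow`, `x`, and the optimizers are as defined in that script, with `batch_size` and `threshold` standing in for `args.batch` and `args.threshold`):

```python
# Condensed from train.py: each batch updates either the EBM (maximizing V)
# or the flow (minimizing V); the choice for the next batch depends on the
# current classification accuracy.
z = flow.base_dist.sample((batch_size,))                       # flow base noise
if train_energy:
    optim_energy.zero_grad()
    loss, acc = util.value(energy, flow, x, z, maximize=True)  # returns -V, acc
    loss.backward()
    optim_energy.step()
else:
    optim_flow.zero_grad()
    loss, acc = util.value(energy, flow, x, z, maximize=False) # returns  V, acc
    loss.backward()
    optim_flow.step()
train_energy = acc <= threshold   # train the EBM until it classifies well enough
```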
36 |
37 | In the paper, the authors use [Glow](https://arxiv.org/abs/1807.03039) as the flow model. This repository implements both Glow and [MAF](https://arxiv.org/abs/1705.07057) as flow models.
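Both flows expose the same small interface used by `value`: `inverse(z)` maps base noise to samples together with their log density, and `log_prob(x)` scores arbitrary points. A minimal sketch, mirroring the instantiation in [train.py](train.py):

```python
# Mirrors train.py: construct a flow and use the interface util.value relies on.
import torch
from flows.glow import Glow
# from flows.maf import MAF   # alternative: MAF(n_blocks=5, input_size=2, hidden_size=100, n_hidden=1)

flow = Glow(width=64, depth=5, n_levels=1, data_dim=2)

z = flow.base_dist.sample((1000,))   # samples from the base distribution
x_gen, logq_gen = flow.inverse(z)    # generated points x̃ and log q(x̃)
logq = flow.log_prob(x_gen)          # log q(·) for arbitrary points
print(x_gen.shape, logq.shape)
```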
38 |
39 | ## Training
40 |
41 | To train the model, run
42 |
43 | ```shell
44 | python train.py
45 | ```
46 |
47 | | Argument | Meaning |
48 | | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------- |
49 | | `--seed=42` | random seed |
50 | | `--epoch=100`          | number of training epochs |
51 | | `--flow=glow` | `glow` or `maf` to use as the flow model |
52 | | `--threshold=0.6` | threshold for alternate training |
53 | | `--batch=1000` | batch size |
54 | | `--dataset=8gaussians` | 2D dataset to use: `8gaussians`, `spiral`, `2spirals`, `checkerboard`, `rings`, `pinwheel` (see the sampling sketch below the table) |
55 | | `--samples=500000` | training set size |
56 | | `--lr_ebm=1e-3` | Adam learning rate for EBM |
57 | | `--lr_flow=7e-4` | Adam learning rate for Flow model |
58 | | `--b1=0.9`             | Adam β1 (decay rate for the first moment) |
59 | | `--b2=0.999`           | Adam β2 (decay rate for the second moment) |
60 |
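To eyeball any of the datasets before training, the sampler in [util.py](util.py) can be used directly (a small sketch, not part of the repository; note that util.py selects the non-interactive `Agg` backend, so the figure is written to a file):

```python
# Draw samples from one of the 2D toy datasets and save a histogram plot.
import matplotlib.pyplot as plt
import util

samples = util.sample_2d_data(dataset='checkerboard', n_samples=20000).numpy()
plt.hist2d(samples[:, 0], samples[:, 1], bins=200)
plt.gca().set_aspect('equal')
plt.savefig('checkerboard.png')
```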
61 | ### Install wandb
62 |
63 | To run the script you need to install [Weights & Biases (wandb)](https://wandb.ai/site), an MLOps tool used here to monitor metrics during training. I find it easy and convenient to use, and I encourage you to give it a try.
64 |
65 | First, sign up at https://wandb.ai/site (you can use your GitHub account) and copy your API key.
66 |
67 | Next, install the Python package through
68 |
69 | ```shell
70 | pip install wandb
71 | ```
72 |
73 | When you start the experiment, you may be asked to log in; simply paste your API key and hit enter. Once the experiment is running, you can go to https://app.wandb.ai/home and follow the metrics in real time.
74 |
75 | Otherwise, if you do not wish to use it, you can comment out all the `wandb`-related lines in [train.py](train.py).
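Alternatively, recent versions of wandb can be kept installed but silenced, so the logging calls become no-ops (a small tweak, not in the repository as-is):

```python
# In train.py, initialize wandb in disabled mode instead of commenting calls out.
wandb.init(project='FCE-2d', mode='disabled')
```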
76 |
77 | ## Visualizations
78 |
79 | ### Density Plots
80 |
81 | The figure below shows the result of Flow Contrastive Estimation using MAF as the flow model. The left column shows three data distributions. The middle column shows the densities learned by MAF; these are also the contrast densities that the energy model has to distinguish from the data. The right column shows the densities learned by the EBM.
82 |
83 | ![FCE results with MAF: data, MAF density, and EBM density](imgs/fce-maf.png)
84 |
85 | For reference, here is the result presented in the FCE paper, showing learned Glow and EBM densities on three data distributions.
86 |
87 | ![FCE results from the paper with Glow as the flow model](imgs/fce-glow.png)
88 |
89 | ### MSE
90 |
91 | In the case of the `8gaussians` dataset, the true data distribution has an analytical form (a mixture of eight Gaussians, implemented as `MixedGaussian` in [util.py](util.py)). We can therefore measure the MSE between the log density learned by the energy model and the true log density. The plot below shows this MSE on the `8gaussians` training set.
92 |
93 | ![MSE between the EBM log density and the true log density on 8gaussians](imgs/mse.png)
94 |
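The checkpoint saved by [train.py](train.py) can be reloaded to recompute this metric offline; here is a sketch (assuming training with the default `glow`/`8gaussians` settings has already produced the checkpoint):

```python
# Reload the trained EBM and recompute the MSE against the analytic mixture.
import torch
import util
from ebm import EBM

device = torch.device('cpu')
energy = EBM().to(device)
state = torch.load('ckpts/fce-glow-2d-8gaussians.pth.tar', map_location=device)
energy.load_state_dict(state['energy'])

data = util.sample_2d_data(dataset='8gaussians', n_samples=10000).to(device)
print(util.mse(energy, util.MixedGaussian(device=device), data))
```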
95 | ### Value
96 |
97 | The figure below shows the negative of the value function during training. If both the EBM $p\_\theta$ and the flow $q\_\alpha$ are close to the data distribution $p\_{\text{data}}$, then $p\_\theta\approx q\_\alpha\approx p\_{\text{data}}$, both classification probabilities are close to $\frac{1}{2}$, and the negative value should be approximately
98 |
99 | $$- V(\alpha,\theta)\approx -\left(\log\frac{1}{2}+\log\frac{1}{2}\right) = \log4 \approx 1.39.$$
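A one-line check of that number (illustrative):

```python
import math
print(-(math.log(0.5) + math.log(0.5)))   # 1.3862943611198906 ≈ log 4
```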
100 |
101 | ![Negative value function during training](imgs/value.png)
102 |
103 | ### Accuracy
104 |
105 | In our experiments we use `0.6` as the default threshold, and as shown below, the classification accuracy of the EBM fluctuates around 0.6.
106 |
107 | ![Classification accuracy of the EBM during training](imgs/acc.png)
108 |
109 | ## Reference
110 |
111 | Gao, Ruiqi, et al. "Flow contrastive estimation of energy-based models." *Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition*. 2020.
--------------------------------------------------------------------------------
/flows/maf.py:
--------------------------------------------------------------------------------
1 | """
2 | Masked Autoregressive Flow for Density Estimation
3 | arXiv:1705.07057v4
4 | """
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.distributions as D
10 |
11 | # --------------------
12 | # Model layers and helpers
13 | # --------------------
14 | def create_masks(input_size, hidden_size, n_hidden, input_order='sequential', input_degrees=None):
15 | # MADE paper sec 4:
16 | # degrees of connections between layers -- ensure at most in_degree - 1 connections
17 | degrees = []
18 |
19 | # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades);
20 | # else init input degrees based on strategy in input_order (sequential or random)
21 | if input_order == 'sequential':
22 | degrees += [torch.arange(input_size)] if input_degrees is None else [input_degrees]
23 | for _ in range(n_hidden + 1):
24 | degrees += [torch.arange(hidden_size) % (input_size - 1)]
25 | degrees += [torch.arange(input_size) % input_size - 1] if input_degrees is None else [input_degrees % input_size - 1]
26 | elif input_order == 'random':
27 | degrees += [torch.randperm(input_size)] if input_degrees is None else [input_degrees]
28 | for _ in range(n_hidden + 1):
29 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
30 | degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))]
31 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
32 | degrees += [torch.randint(min_prev_degree, input_size, (input_size,)) - 1] if input_degrees is None else [input_degrees - 1]
33 |
34 | # construct masks
35 | masks = []
36 | for (d0, d1) in zip(degrees[:-1], degrees[1:]):
37 | masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()]
38 |
39 | return masks, degrees[0]
40 |
41 |
42 |
43 | class MaskedLinear(nn.Linear):
44 | """ MADE building block layer """
45 | def __init__(self, input_size, n_outputs, mask, cond_label_size=None):
46 | super().__init__(input_size, n_outputs)
47 |
48 | self.register_buffer('mask', mask)
49 |
50 | self.cond_label_size = cond_label_size
51 | if cond_label_size is not None:
52 | self.cond_weight = nn.Parameter(torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size))
53 |
54 | def forward(self, x, y=None):
55 | out = F.linear(x, self.weight * self.mask, self.bias)
56 | if y is not None:
57 | out = out + F.linear(y, self.cond_weight)
58 | return out
59 |
60 | def extra_repr(self):
61 | return 'in_features={}, out_features={}, bias={}'.format(
62 | self.in_features, self.out_features, self.bias is not None
63 | ) + (self.cond_label_size != None) * ', cond_features={}'.format(self.cond_label_size)
64 |
65 |
66 |
67 | class FlowSequential(nn.Sequential):
68 | """ Container for layers of a normalizing flow """
69 | def forward(self, x, y):
70 | sum_log_abs_det_jacobians = 0
71 | for module in self:
72 | x, log_abs_det_jacobian = module(x, y)
73 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian
74 | return x, sum_log_abs_det_jacobians
75 |
76 | def inverse(self, u, y):
77 | sum_log_abs_det_jacobians = 0
78 | for module in reversed(self):
79 | u, log_abs_det_jacobian = module.inverse(u, y)
80 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian
81 | return u, sum_log_abs_det_jacobians
82 |
83 |
84 | class BatchNorm(nn.Module):
85 | """ RealNVP BatchNorm layer """
86 | def __init__(self, input_size, momentum=0.9, eps=1e-5):
87 | super().__init__()
88 | self.momentum = momentum
89 | self.eps = eps
90 |
91 | self.log_gamma = nn.Parameter(torch.zeros(input_size))
92 | self.beta = nn.Parameter(torch.zeros(input_size))
93 |
94 | self.register_buffer('running_mean', torch.zeros(input_size))
95 | self.register_buffer('running_var', torch.ones(input_size))
96 |
97 | def forward(self, x, cond_y=None):
98 | if self.training:
99 | self.batch_mean = x.mean(0)
100 | self.batch_var = x.var(0) # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False)
101 |
102 | # update running mean
103 | self.running_mean.mul_(self.momentum).add_(self.batch_mean.data * (1 - self.momentum))
104 | self.running_var.mul_(self.momentum).add_(self.batch_var.data * (1 - self.momentum))
105 |
106 | mean = self.batch_mean
107 | var = self.batch_var
108 | else:
109 | mean = self.running_mean
110 | var = self.running_var
111 |
112 | # compute normalized input (cf original batch norm paper algo 1)
113 | x_hat = (x - mean) / torch.sqrt(var + self.eps)
114 | y = self.log_gamma.exp() * x_hat + self.beta
115 |
116 | # compute log_abs_det_jacobian (cf RealNVP paper)
117 | log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps)
118 | # print('in sum log var {:6.3f} ; out sum log var {:6.3f}; sum log det {:8.3f}; mean log_gamma {:5.3f}; mean beta {:5.3f}'.format(
119 | # (var + self.eps).log().sum().data.numpy(), y.var(0).log().sum().data.numpy(), log_abs_det_jacobian.mean(0).item(), self.log_gamma.mean(), self.beta.mean()))
120 | return y, log_abs_det_jacobian.expand_as(x)
121 |
122 | def inverse(self, y, cond_y=None):
123 | if self.training:
124 | mean = self.batch_mean
125 | var = self.batch_var
126 | else:
127 | mean = self.running_mean
128 | var = self.running_var
129 |
130 | x_hat = (y - self.beta) * torch.exp(-self.log_gamma)
131 | x = x_hat * torch.sqrt(var + self.eps) + mean
132 |
133 | log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma
134 |
135 | return x, log_abs_det_jacobian.expand_as(x)
136 |
137 |
138 | # --------------------
139 | # Models
140 | # --------------------
141 | class MADE(nn.Module):
142 | def __init__(self, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', input_degrees=None):
143 | """
144 | Args:
145 | input_size -- scalar; dim of inputs
146 | hidden_size -- scalar; dim of hidden layers
147 | n_hidden -- scalar; number of hidden layers
148 | activation -- str; activation function to use
149 | input_order -- str or tensor; variable order for creating the autoregressive masks (sequential|random)
150 | or the order flipped from the previous layer in a stack of mades
151 |             cond_label_size -- int or None; size of the conditioning label (None for an unconditional model)
152 | """
153 | super().__init__()
154 | # base distribution for calculation of log prob under the model
155 | self.register_buffer('base_dist_mean', torch.zeros(input_size))
156 | self.register_buffer('base_dist_var', torch.ones(input_size))
157 |
158 | # create masks
159 | masks, self.input_degrees = create_masks(input_size, hidden_size, n_hidden, input_order, input_degrees)
160 |
161 | # setup activation
162 | if activation == 'relu':
163 | activation_fn = nn.ReLU()
164 | elif activation == 'tanh':
165 | activation_fn = nn.Tanh()
166 | else:
167 | raise ValueError('Check activation function.')
168 |
169 | # construct model
170 | self.net_input = MaskedLinear(input_size, hidden_size, masks[0], cond_label_size)
171 | self.net = []
172 | for m in masks[1:-1]:
173 | self.net += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)]
174 | self.net += [activation_fn, MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2,1))]
175 | self.net = nn.Sequential(*self.net)
176 |
177 | @property
178 | def base_dist(self):
179 | return D.Normal(self.base_dist_mean, self.base_dist_var)
180 |
181 | def forward(self, x, y=None):
182 | # MAF eq 4 -- return mean and log std
183 | mu, alpha = self.net(self.net_input(x, y)).chunk(chunks=2, dim=1)
184 | u = (x - mu) * torch.exp(-alpha)
185 | # MAF eq 5
186 | logdet = - alpha
187 | return u, logdet
188 |
189 | def inverse(self, u, y=None, sum_log_abs_det_jacobians=None):
190 | # MAF eq 3
191 | # D = u.shape[1]
192 | x = torch.zeros_like(u)
193 | # run through reverse model
194 | for i in self.input_degrees:
195 | mu, alpha = self.net(self.net_input(x.clone(), y)).chunk(chunks=2, dim=1)
196 | x[:,i] = mu[:,i] + u[:,i] * torch.exp(alpha[:,i])
197 | logdet = alpha
198 | return x, logdet
199 |
200 | def log_prob(self, x, y=None):
201 | u, logdet = self.forward(x, y)
202 | return torch.sum(self.base_dist.log_prob(u) + logdet, dim=1, keepdim=True)
203 |
204 |
205 |
206 | class MAF(nn.Module):
207 | def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', batch_norm=True):
208 | super().__init__()
209 | # base distribution for calculation of log prob under the model
210 | self.register_buffer('base_dist_mean', torch.zeros(input_size))
211 | self.register_buffer('base_dist_var', torch.ones(input_size))
212 |
213 |
214 | # construct model
215 | modules = []
216 | self.input_degrees = None
217 | for i in range(n_blocks):
218 | modules += [MADE(input_size, hidden_size, n_hidden, cond_label_size, activation, input_order, self.input_degrees)]
219 | self.input_degrees = modules[-1].input_degrees.flip(0)
220 | # modules += batch_norm * [BatchNorm(input_size)]
221 |
222 | self.net = FlowSequential(*modules)
223 |
224 | @property
225 | def base_dist(self):
226 | return D.Normal(self.base_dist_mean, self.base_dist_var)
227 |
228 | def forward(self, x, y=None):
229 | return self.net(x, y)
230 |
231 | def inverse(self, u, y=None):
232 | gen, logdet = self.net.inverse(u, y)
233 |
234 | # get logq(x̃), where x̃ = f^{-1}(z)
235 | logq_gen = torch.sum(self.base_dist.log_prob(u) - logdet, dim=1, keepdim=True)
236 |
237 | return gen, logq_gen
238 |
239 |
240 | def log_prob(self, x, y=None):
241 | u, logdet = self.forward(x, y)
242 | return torch.sum(self.base_dist.log_prob(u) + logdet, dim=1, keepdim=True)
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import DataLoader
5 | import torch.distributions as D
6 | import torch.nn.functional as F
7 |
8 | import matplotlib
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 | font = {'family' : 'sans-serif',
12 | 'weight' : 'bold',
13 | 'size' : 16}
14 | matplotlib.rc('font', **font)
15 |
16 | # ------------------------------
17 | # FCE Value Function
18 | # ------------------------------
19 | def value(energy, flow, x, z, maximize=True):
20 | gen, logq_gen = flow.inverse(z) # the second term is logq(x̃)
21 | logp_x = energy(x) # logp(x)
22 | logp_gen = energy(gen) #logp(x̃)
23 | logq_x = flow.log_prob(x) # logq(x)
24 |
25 |     value_x = logp_x - torch.logsumexp(torch.cat([logp_x, logq_x], dim=1), dim=1, keepdim=True) # log[p(x) / (p(x) + q(x))]
26 |     value_gen = logq_gen - torch.logsumexp(torch.cat([logp_gen, logq_gen], dim=1), dim=1, keepdim=True) # log[q(x̃) / (p(x̃) + q(x̃))]
27 |
28 | v = value_x.mean() + value_gen.mean()
29 |
30 | # calculate accuracy
31 | r_x = torch.sigmoid(logp_x - logq_x)
32 | r_gen = torch.sigmoid(logq_gen - logp_gen)
33 | acc = ((r_x >= 1/2).sum() + (r_gen > 1/2).sum()).cpu().numpy() / (len(r_x) + len(r_gen))
34 |
35 | if maximize:
36 | return -v, acc
37 | else:
38 | return v, acc
39 |
40 | # ------------------------------
41 | # MIXED GAUSSIAN DENSITY
42 | # ------------------------------
43 | class MixedGaussian():
44 | def __init__(self, device):
45 | r = 2. * math.sqrt(2)
46 | self.covariance_matrix = (1/8 * torch.eye(2)).to(device) # variance on the diagonal, not standard deviation
47 | self.m1 = D.MultivariateNormal(loc=torch.tensor([r, 0.], device=device), covariance_matrix=self.covariance_matrix)
48 | self.m2 = D.MultivariateNormal(loc=torch.tensor([-r, 0.], device=device), covariance_matrix=self.covariance_matrix)
49 | self.m3 = D.MultivariateNormal(loc=torch.tensor([0., r], device=device), covariance_matrix=self.covariance_matrix)
50 | self.m4 = D.MultivariateNormal(loc=torch.tensor([0., -r], device=device), covariance_matrix=self.covariance_matrix)
51 | self.m5 = D.MultivariateNormal(loc=torch.tensor([2., 2.], device=device), covariance_matrix=self.covariance_matrix)
52 | self.m6 = D.MultivariateNormal(loc=torch.tensor([-2., 2.], device=device), covariance_matrix=self.covariance_matrix)
53 | self.m7 = D.MultivariateNormal(loc=torch.tensor([2., -2.], device=device), covariance_matrix=self.covariance_matrix)
54 | self.m8 = D.MultivariateNormal(loc=torch.tensor([-2., -2.], device=device), covariance_matrix=self.covariance_matrix)
55 |
56 | def log_prob(self, x):
57 | log_probs = torch.cat([self.m1.log_prob(x).unsqueeze(1),
58 | self.m2.log_prob(x).unsqueeze(1),
59 | self.m3.log_prob(x).unsqueeze(1),
60 | self.m4.log_prob(x).unsqueeze(1),
61 | self.m5.log_prob(x).unsqueeze(1),
62 | self.m6.log_prob(x).unsqueeze(1),
63 | self.m7.log_prob(x).unsqueeze(1),
64 | self.m8.log_prob(x).unsqueeze(1)], dim=1)
65 | out = - math.log(8.) + torch.logsumexp(log_probs, dim=1, keepdim=True)
66 | return out
67 |
68 | # MSE between EBM and a true distribution.
69 | def mse(ebm, true_dist, data_batch):
70 | with torch.no_grad():
71 | ebm_outputs = ebm(data_batch)
72 | true_values = true_dist.log_prob(data_batch)
73 | return F.mse_loss(input=ebm_outputs, target=true_values).item()
74 |
75 | #-------------------------------------------
76 | # DATA
77 | #-------------------------------------------
78 | def get_data(args):
79 | dataset = sample_2d_data(dataset=args.dataset, n_samples=args.samples)
80 | dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=True)
81 | return dataset, dataloader
82 |
83 |
84 | def sample_2d_data(dataset='8gaussians', n_samples=50000):
85 |
86 | z = torch.randn(n_samples, 2)
87 |
88 | if dataset == '8gaussians':
89 | scale = 4
90 | sq2 = 1/math.sqrt(2)
91 | centers = [(1,0), (-1,0), (0,1), (0,-1), (sq2,sq2), (-sq2,sq2), (sq2,-sq2), (-sq2,-sq2)]
92 | centers = torch.tensor([(scale * x, scale * y) for x,y in centers])
93 | return sq2 * (0.5 * z + centers[torch.randint(len(centers), size=(n_samples,))])
94 | # return 0.05 * z + centers[torch.randint(len(centers), size=(n_samples,))]
95 |
96 | elif dataset == '2spirals':
97 | n = torch.sqrt(torch.rand(n_samples // 2)) * 540 * (2 * math.pi) / 360
98 | d1x = - torch.cos(n) * n + torch.rand(n_samples // 2) * 0.5
99 | d1y = torch.sin(n) * n + torch.rand(n_samples // 2) * 0.5
100 | x = torch.cat([torch.stack([ d1x, d1y], dim=1),
101 | torch.stack([-d1x, -d1y], dim=1)], dim=0) / 3
102 | return x + 0.1*z
103 |
104 | elif dataset == 'spiral':
105 | n = torch.sqrt(torch.rand(n_samples)) * 540 * (2 * math.pi) / 360
106 | d1x = - torch.cos(n) * n + torch.rand(n_samples) * 0.5
107 | d1y = torch.sin(n) * n + torch.rand(n_samples) * 0.5
108 |         # print(d1x.shape)  # leftover debug print
109 |         x = torch.stack([ d1x, d1y], dim=1) / 3
110 |         # print(x.shape)  # leftover debug print
111 | return x + 0.1*z
112 |
113 | elif dataset == 'checkerboard':
114 | x1 = torch.rand(n_samples) * 4 - 2
115 | x2_ = torch.rand(n_samples) - torch.randint(0, 2, (n_samples,), dtype=torch.float) * 2
116 | x2 = x2_ + x1.floor() % 2
117 | return torch.stack([x1, x2], dim=1) * 2
118 |
119 | elif dataset == 'rings':
120 | n_samples4 = n_samples3 = n_samples2 = n_samples // 4
121 | n_samples1 = n_samples - n_samples4 - n_samples3 - n_samples2
122 | # so as not to have the first point = last point, set endpoint=False in np; here shifted by one
123 | linspace4 = torch.linspace(0, 2 * math.pi, n_samples4 + 1)[:-1]
124 | linspace3 = torch.linspace(0, 2 * math.pi, n_samples3 + 1)[:-1]
125 | linspace2 = torch.linspace(0, 2 * math.pi, n_samples2 + 1)[:-1]
126 | linspace1 = torch.linspace(0, 2 * math.pi, n_samples1 + 1)[:-1]
127 | circ4_x = torch.cos(linspace4)
128 | circ4_y = torch.sin(linspace4)
129 |         circ3_x = torch.cos(linspace3) * 0.75
130 | circ3_y = torch.sin(linspace3) * 0.75
131 | circ2_x = torch.cos(linspace2) * 0.5
132 | circ2_y = torch.sin(linspace2) * 0.5
133 | circ1_x = torch.cos(linspace1) * 0.25
134 | circ1_y = torch.sin(linspace1) * 0.25
135 |
136 | x = torch.stack([torch.cat([circ4_x, circ3_x, circ2_x, circ1_x]),
137 | torch.cat([circ4_y, circ3_y, circ2_y, circ1_y])], dim=1) * 3.0
138 |
139 | # random sample
140 | x = x[torch.randint(0, n_samples, size=(n_samples,))]
141 |
142 | # Add noise
143 | return x + torch.normal(mean=torch.zeros_like(x), std=0.08*torch.ones_like(x))
144 |
145 | elif dataset == "pinwheel":
146 | rng = np.random.RandomState()
147 | radial_std = 0.3
148 | tangential_std = 0.1
149 | num_classes = 5
150 | num_per_class = n_samples // 5
151 | rate = 0.25
152 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
153 | features = rng.randn(num_classes*num_per_class, 2) * np.array([radial_std, tangential_std])
154 | # features = np.random.randn(num_classes*num_per_class, 2) * np.array([radial_std, tangential_std])
155 | features[:, 0] += 1.
156 | labels = np.repeat(np.arange(num_classes), num_per_class)
157 |
158 | angles = rads[labels] + rate * np.exp(features[:, 0])
159 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
160 | rotations = np.reshape(rotations.T, (-1, 2, 2))
161 |
162 | data = 2 * rng.permutation(np.einsum("ti,tij->tj", features, rotations))
163 | # data = 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations))
164 | return torch.as_tensor(data, dtype=torch.float32)
165 |
166 | else:
167 | raise RuntimeError('Invalid `dataset` to sample from.')
168 |
169 | # --------------------
170 | # Plotting
171 | # --------------------
172 |
173 | @torch.no_grad()
174 | def plot(dataset, energy, flow, epoch, device):
175 | n_pts = 1000
176 | range_lim = 4
177 | # construct test points
178 | test_grid = setup_grid(range_lim, n_pts, device)
179 |
180 | # plot
181 | fig, axs = plt.subplots(2, 2, figsize=(8,8), subplot_kw={'aspect': 'equal'})
182 | plot_samples(dataset, axs[0][0], range_lim, n_pts)
183 | plot_flow(flow, axs[0][1], test_grid, n_pts)
184 | plot_flow_samples(flow, axs[1][0], range_lim, n_pts)
185 | plot_energy(energy, axs[1][1], test_grid, n_pts)
186 |
187 | # format
188 | for ax in plt.gcf().axes: format_ax(ax, range_lim)
189 | plt.tight_layout(pad=2.0)
190 |
191 | # save
192 | print('Saving image to images/....')
193 | plt.savefig('images/epoch_{}.png'.format(epoch))
194 | plt.close()
195 |
196 |
197 | def setup_grid(range_lim, n_pts, device):
198 | x = torch.linspace(-range_lim, range_lim, n_pts)
199 | xx, yy = torch.meshgrid((x, x))
200 | zz = torch.stack((xx.flatten(), yy.flatten()), dim=1)
201 | return xx, yy, zz.to(device)
202 |
203 | def plot_samples(dataset, ax, range_lim, n_pts):
204 | samples = dataset.numpy()
205 | ax.hist2d(samples[:,0], samples[:,1], range=[[-range_lim, range_lim], [-range_lim, range_lim]], bins=n_pts, cmap=plt.cm.jet)
206 | ax.set_title('Data')
207 |
208 | def plot_energy(energy, ax, test_grid, n_pts):
209 | xx, yy, zz = test_grid
210 | log_prob = energy(zz)
211 | prob = log_prob.exp().cpu()
212 | ax.pcolormesh(xx, yy, prob.view(n_pts,n_pts), cmap=plt.cm.jet)
213 | ax.set_facecolor(plt.cm.jet(0.))
214 | ax.set_title('EBM')
215 |
216 | def plot_flow(flow, ax, test_grid, n_pts):
217 | flow.eval()
218 | xx, yy, zz = test_grid
219 | log_prob = flow.log_prob(zz)
220 | prob = log_prob.exp().cpu()
221 | ax.pcolormesh(xx, yy, prob.view(n_pts,n_pts), cmap=plt.cm.jet)
222 | ax.set_facecolor(plt.cm.jet(0.))
223 | ax.set_title('Flow')
224 |
225 | def plot_flow_samples(flow, ax, range_lim, n_pts):
226 | z = flow.base_dist.sample((10000,))
227 | samples, _ = flow.inverse(z)
228 | samples = samples.cpu().numpy()
229 | ax.hist2d(samples[:,0], samples[:,1], range=[[-range_lim, range_lim], [-range_lim, range_lim]], bins=n_pts, cmap=plt.cm.jet)
230 | ax.set_title('Flow-samples')
231 |
232 | def format_ax(ax, range_lim):
233 | ax.set_xlim(-range_lim, range_lim)
234 | ax.set_ylim(-range_lim, range_lim)
235 | ax.get_xaxis().set_visible(False)
236 | ax.get_yaxis().set_visible(False)
237 | ax.invert_yaxis()
--------------------------------------------------------------------------------