├── imgs ├── acc.png ├── mse.png ├── fce-maf.png ├── value.png └── fce-glow.png ├── .gitattributes ├── ebm.py ├── LICENSE ├── train.py ├── flows ├── glow.py └── maf.py ├── README.md └── util.py /imgs/acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/acc.png -------------------------------------------------------------------------------- /imgs/mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/mse.png -------------------------------------------------------------------------------- /imgs/fce-maf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/fce-maf.png -------------------------------------------------------------------------------- /imgs/value.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/value.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /imgs/fce-glow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/volagold/fce-2d/HEAD/imgs/fce-glow.png -------------------------------------------------------------------------------- /ebm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class EBM(nn.Module): 6 | def __init__(self): 7 | super(EBM, self).__init__() 8 | # The normalizing constant logZ(θ) 9 | self.c = nn.Parameter(torch.tensor([1.], requires_grad=True)) 10 | 11 | self.f = nn.Sequential( 12 | nn.Linear(2, 128), 13 | nn.LeakyReLU(0.2, inplace=True), 14 | nn.Linear(128, 128), 15 | nn.LeakyReLU(0.2, inplace=True), 16 | nn.Linear(128, 1), 17 | ) 18 | 19 | def forward(self, x): 20 | log_prob = - self.f(x) - self.c 21 | return log_prob -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 franli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Flow Contrastive Estimation (FCE) on 2D dataset. 3 | https://arxiv.org/abs/1912.00589 4 | """ 5 | import os 6 | import math 7 | import argparse 8 | import torch 9 | from ebm import EBM 10 | import util 11 | 12 | import wandb 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Flow Contrastive Estimation of Energy Based Model') 16 | parser.add_argument('--seed', default=42, type=int, help='seed') 17 | parser.add_argument('--epoch', default=100, type=int, help='number of training epochs') 18 | parser.add_argument('--flow', default='glow', type=str, help='Flow model to use') 19 | parser.add_argument('--threshold', default=0.6, type=float, help='threshold for alternate training') 20 | parser.add_argument('--batch', default=1000, type=int, help='batch size') 21 | parser.add_argument('--dataset', default='8gaussians', type=str, choices=['8gaussians', 'spiral', '2spirals', 'checkerboard', 'rings', 'pinwheel'], help='2D dataset to use') 22 | parser.add_argument('--samples', default=500000, type=int, help='number of 2D samples for training') 23 | parser.add_argument('--lr_ebm', default=1e-3, type=float, help='learning rate for EBM') 24 | parser.add_argument('--lr_flow', default=7e-4, type=float, help='learning rate for Flow') 25 | parser.add_argument('--b1', type=float, default=0.9, help='adam: decay of first order momentum of gradient') 26 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 27 | args = parser.parse_args() 28 | 29 | wandb.init(project='FCE-2d') 30 | wandb.config.update(args) 31 | 32 | device = torch.device('cuda') 33 | torch.manual_seed(args.seed) 34 | # ------------------------------ 35 | # I. MODELS 36 | # ------------------------------ 37 | energy = EBM().to(device) 38 | if args.flow == 'glow': 39 | from flows.glow import Glow 40 | flow = Glow(width=64, depth=5, n_levels=1, data_dim=2).to(device) 41 | elif args.flow == 'maf': 42 | from flows.maf import MAF 43 | flow = MAF(n_blocks=5, input_size=2, hidden_size=100, n_hidden=1).to(device) 44 | # ------------------------------ 45 | # II. OPTIMIZERS 46 | # ------------------------------ 47 | optim_energy = torch.optim.Adam(energy.parameters(), lr=args.lr_ebm, betas=(args.b1, args.b2)) 48 | optim_flow = torch.optim.Adam(flow.parameters(), lr=args.lr_flow, betas=(args.b1, args.b2)) 49 | # ------------------------------ 50 | # III. DATA LOADER 51 | # ------------------------------ 52 | dataset, dataloader = util.get_data(args) 53 | dataset = dataset.to(device) 54 | # ------------------------------ 55 | # IV. 
TRAINING 56 | # ------------------------------ 57 | wandb.watch(energy) 58 | wandb.watch(flow) 59 | 60 | def main(args): 61 | train_energy = True 62 | for epoch in range(args.epoch): 63 | for i, x in enumerate(dataloader): 64 | x = x.to(device) 65 | # ----------------------------- 66 | # Generate noise 67 | # ----------------------------- 68 | z = flow.base_dist.sample((args.batch,)) 69 | # ----------------------------- 70 | # Train Energy Model 71 | # ----------------------------- 72 | if train_energy: 73 | optim_energy.zero_grad() 74 | loss_energy, acc = util.value(energy, flow, x, z, maximize=True) 75 | loss_energy.backward() 76 | optim_energy.step() 77 | # ----------------------------- 78 | # Train Flow Model 79 | # ----------------------------- 80 | else: 81 | optim_flow.zero_grad() 82 | loss_flow, acc = util.value(energy, flow, x, z, maximize=False) 83 | loss_flow.backward() 84 | optim_flow.step() 85 | 86 | wandb.log({'epoch': epoch, 87 | 'value': loss_energy.item() if train_energy else -loss_flow.item(), 88 | 'acc': acc, 89 | 'mse': util.mse(energy, util.MixedGaussian(device=device), dataset), # comment out if not using 8gaussians 90 | }) 91 | 92 | 93 | if acc > args.threshold: 94 | train_energy = False 95 | else: 96 | train_energy = True 97 | 98 | 99 | # Save checkpoint 100 | print('Saving models...') 101 | state = { 102 | 'energy': energy.state_dict(), 103 | 'flow': flow.state_dict(), 104 | 'value': loss_energy, 105 | 'epoch': epoch, 106 | } 107 | os.makedirs('ckpts', exist_ok=True) 108 | ckpts = 'ckpts/fce-{}-2d-{}.pth.tar'.format(args.flow, args.dataset) 109 | torch.save(state, ckpts) 110 | 111 | # visualization 112 | # util.plot(dataset, energy, flow, epoch, device) 113 | 114 | 115 | 116 | 117 | 118 | if __name__ == '__main__': 119 | print(args) 120 | main(args) 121 | -------------------------------------------------------------------------------- /flows/glow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Glow[1] model in PyTorch for 2D dataset. 
Adapted from 3 | https://github.com/kamenbliznashki/normalizing_flows/blob/master/glow.py 4 | 5 | [1] https://arxiv.org/abs/1807.03039 6 | """ 7 | import os 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.distributions as D 13 | 14 | # -------------------- 15 | # Model component layers 16 | # -------------------- 17 | class Actnorm(nn.Module): 18 | def __init__(self, param_dim=(1,2)): 19 | super().__init__() 20 | self.scale = nn.Parameter(torch.ones(param_dim)) 21 | self.bias = nn.Parameter(torch.zeros(param_dim)) 22 | self.register_buffer('initialized', torch.tensor(0).byte()) 23 | 24 | def forward(self, x): 25 | if not self.initialized: 26 | # x.shape = (B, W) 27 | self.bias.squeeze().data.copy_(x.transpose(0,1).flatten(1).mean(1)).view_as(self.scale) 28 | self.scale.squeeze().data.copy_(x.transpose(0,1).flatten(1).std(1, unbiased=False) + 1e-6).view_as(self.bias) 29 | self.initialized += 1 30 | 31 | z = (x - self.bias) / self.scale 32 | logdet = - self.scale.abs().log().sum() 33 | return z, logdet 34 | 35 | def inverse(self, z): 36 | x = z * self.scale + self.bias 37 | logdet = self.scale.abs().log().sum() 38 | return x, logdet 39 | 40 | 41 | class Invertible1x1Conv(nn.Module): 42 | def __init__(self, dim=2): 43 | super().__init__() 44 | 45 | w = torch.randn(dim, dim) 46 | w = torch.linalg.qr(w)[0] # W^{-1} = W^T (only at initialization) 47 | self.w = nn.Parameter(w) 48 | 49 | def forward(self, x): 50 | logdet = torch.slogdet(self.w)[-1] 51 | return x @ self.w.t(), logdet # (WX)^T = X^TW^T = y^T. 52 | 53 | def inverse(self, z): 54 | w_inv = self.w.t().inverse() 55 | logdet = - torch.slogdet(self.w)[-1] 56 | return z @ w_inv, logdet 57 | 58 | 59 | class AffineCoupling(nn.Module): 60 | def __init__(self, dim=2, width=128): 61 | super().__init__() 62 | self.fc1 = nn.Linear(dim//2, width, bias=True) 63 | self.actnorm1 = Actnorm(param_dim=(1, width)) 64 | self.fc2 = nn.Linear(width, width, bias=True) 65 | self.actnorm2 = Actnorm(param_dim=(1, width)) 66 | self.fc3 = nn.Linear(width, dim, bias=True) 67 | self.log_scale_factor = nn.Parameter(torch.zeros(1,2)) 68 | 69 | self.fc3.weight.data.zero_() 70 | self.fc3.bias.data.zero_() 71 | 72 | def forward(self, x): 73 | x_a, x_b = x.chunk(2, dim=1) # x.shape = [batch, 2] 74 | 75 | h = F.relu(self.actnorm1(self.fc1(x_b))[0]) 76 | h = F.relu(self.actnorm2(self.fc2(h))[0]) 77 | h = self.fc3(h) * self.log_scale_factor.exp() 78 | t = h[:,0::2] # shift; take even dimension(s) 79 | s = h[:,1::2] # scale; take odd dimension(s) 80 | s = torch.sigmoid(s + 2.) 81 | 82 | z_a = s * x_a + t 83 | z_b = x_b 84 | z = torch.cat([z_a, z_b], dim=1) # z.shape = [batch, 2] 85 | logdet = s.log().sum(1) 86 | 87 | return z, logdet 88 | 89 | def inverse(self, z): 90 | z_a, z_b = z.chunk(2, dim=1) 91 | 92 | h = F.relu(self.actnorm1(self.fc1(z_b))[0]) 93 | h = F.relu(self.actnorm2(self.fc2(h))[0]) 94 | h = self.fc3(h) * self.log_scale_factor.exp() 95 | t = h[:,0::2] # shift; take even dimension(s) 96 | s = h[:,1::2] # scale; take odd dimension(s) 97 | s = torch.sigmoid(s + 2.) 98 | 99 | x_a = (z_a - t) / s 100 | x_b = z_b 101 | x = torch.cat([x_a, x_b], dim=1) 102 | 103 | logdet = - s.log().sum(1) 104 | return x, logdet 105 | 106 | # -------------------- 107 | # Container layers 108 | # -------------------- 109 | class FlowSequential(nn.Sequential): 110 | def __init__(self, *args, **kwargs): 111 | super().__init__(*args, **kwargs) 112 | 113 | def forward(self, x): 114 | slogdet = 0. 
115 | for module in self: 116 | x, logdet = module(x) 117 | slogdet = slogdet + logdet 118 | return x, slogdet 119 | 120 | def inverse(self, z): 121 | slogdet = 0. 122 | for module in reversed(self): 123 | z, logdet = module.inverse(z) 124 | slogdet = slogdet + logdet 125 | return z, slogdet 126 | 127 | 128 | class FlowStep(FlowSequential): 129 | """ One step (Actnorm -> Invertible 1x1 conv -> Affine coupling) """ 130 | def __init__(self, dim=2, width=128): 131 | super().__init__( 132 | Actnorm(param_dim=(1,dim)), 133 | Invertible1x1Conv(dim=dim), 134 | AffineCoupling(dim=dim, width=width)) 135 | 136 | 137 | class FlowLevel(nn.Module): 138 | """ One depth (e.g. 10) (FlowStep x 10) """ 139 | def __init__(self, dim=2, width=128, depth=10): 140 | super().__init__() 141 | self.flowsteps = FlowSequential(*[FlowStep(dim, width) for _ in range(depth)]) # original: FlowStep(4*n_channels, width) 142 | 143 | def forward(self, x): 144 | z, logdet = self.flowsteps(x) 145 | return z, logdet 146 | 147 | def inverse(self, z): 148 | x, logdet = self.flowsteps.inverse(z) 149 | return x, logdet 150 | 151 | 152 | # -------------------- 153 | # Model 154 | # -------------------- 155 | class Glow(nn.Module): 156 | """ Glow multi-scale architecture with depth of flow K and number of levels L""" 157 | def __init__(self, width=128, depth=10, n_levels=1, data_dim=2): 158 | super().__init__() 159 | 160 | # (FlowStep x depth) x n_levels 161 | self.flowlevels = nn.ModuleList([FlowLevel(dim=data_dim, width=width, depth=depth) for i in range(n_levels)]) 162 | self.flowstep = FlowSequential(*[FlowStep(dim=data_dim, width=width) for _ in range(depth)]) 163 | 164 | # base distribution of the flow 165 | self.register_buffer('base_dist_mean', torch.zeros(2)) 166 | self.register_buffer('base_dist_var', torch.eye(2)) 167 | 168 | @property 169 | def base_dist(self): 170 | return D.MultivariateNormal(self.base_dist_mean, self.base_dist_var) 171 | 172 | def forward(self, x): 173 | slogdet = 0 174 | for m in self.flowlevels: 175 | z, logdet = m(x) 176 | slogdet = slogdet + logdet 177 | z, logdet = self.flowstep(z) 178 | slogdet = slogdet + logdet 179 | return z, slogdet 180 | 181 | def inverse(self, z=None, batch_size=32, z_std=1.): 182 | if z is None: 183 | z = z_std * self.base_dist.sample((batch_size,)) 184 | x, slogdet = self.flowstep.inverse(z) 185 | for m in reversed(self.flowlevels): 186 | x, logdet = m.inverse(x) 187 | slogdet = slogdet + logdet 188 | 189 | # get logq(x̃), where x̃ = f^{-1}(z) 190 | logq_gen = (self.base_dist.log_prob(z) - slogdet).unsqueeze(1) 191 | 192 | return x, logq_gen 193 | 194 | def log_prob(self, x): 195 | z, logdet = self.forward(x) 196 | log_prob = self.base_dist.log_prob(z) + logdet 197 | return log_prob.unsqueeze(1) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flow Contrastive Estimation (FCE) 2 | 3 | This is an implementation of [Flow Contrastive Estimation](https://openaccess.thecvf.com/content_CVPR_2020/html/Gao_Flow_Contrastive_Estimation_of_Energy-Based_Models_CVPR_2020_paper.html) in PyTorch on 2D dataset. 4 | 5 | ## Introduction 6 | 7 | ### Direct Estimation of Energy Model is Difficult 8 | 9 | Our problem is to estimate an energy based model (EBM) 10 | 11 | $$p\_\theta(x) = \frac{\exp[-f\_\theta(x)]}{Z(\theta)}$$ 12 | 13 | where 14 | 15 | $$Z(\theta) = \int\exp[-f\_\theta(x)]dx$$ 16 | 17 | is the normalizing constant. 
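For the 2D toy data used in this repository, $Z(\theta)$ can actually be approximated by brute force on a grid, which makes a handy sanity check on the trainable constant $c\approx\log Z(\theta)$ kept in [ebm.py](ebm.py). Below is a minimal sketch; the helper `grid_log_partition` is illustrative and not part of the repository, and it assumes the density has negligible mass outside the grid.

```python
import torch
from ebm import EBM

@torch.no_grad()
def grid_log_partition(energy, range_lim=4.0, n_pts=500):
    """Riemann-sum estimate of log Z(θ) = log ∫ exp(-f_θ(x)) dx over [-range_lim, range_lim]^2."""
    xs = torch.linspace(-range_lim, range_lim, n_pts)
    grid = torch.cartesian_prod(xs, xs)              # (n_pts**2, 2) grid points
    neg_f = energy(grid) + energy.c                  # energy() returns -f(x) - c, so add c back
    log_cell_area = 2 * torch.log(xs[1] - xs[0])     # log(Δx · Δy)
    return (torch.logsumexp(neg_f.squeeze(1), dim=0) + log_cell_area).item()

energy = EBM()
# for a trained model (loaded from a checkpoint), the grid estimate should be
# close to the learned constant energy.c
print(grid_log_partition(energy), energy.c.item())
```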
The energy model specifies a probability distribution on the data space. In higher dimensions, however, the normalizing constant is very difficult to compute, since it requires integrating over the entire space (or summing over an exponential number of configurations in the discrete case).

The energy-based model is implemented in [ebm.py](ebm.py).

### NCE: Teach EBM to Classify Data and Noise

One approach to estimating an EBM is [Noise Contrastive Estimation (NCE)](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf). In NCE, the normalizing constant is treated as a trainable parameter, and the model parameters are estimated by training the EBM to classify data against noise. Let $p\_{\mathrm{data}}(x)$ denote the data distribution and let $q(x)$ denote some noise distribution. This amounts to maximizing the following posterior log-likelihood of the classification:

$$V(\theta) = \mathbb{E}\_{x\sim p\_{\text{data}}}\log\frac{p\_\theta(x)}{p\_\theta(x)+q(x)} + \mathbb{E}\_{\tilde{x}\sim q}\log\frac{q(\tilde{x})}{p\_\theta(\tilde{x}) + q(\tilde{x})}.$$

### FCE: Replace Noise in NCE with Flow Model

In Flow Contrastive Estimation (FCE), we replace the fixed noise $q(x)$ with a flow model $q\_\alpha(x)$ and jointly train the two models, iteratively maximizing (over $\theta$) and minimizing (over $\alpha$) the posterior log-likelihood of the classification:

$$V(\alpha,\theta) = \mathbb{E}\_{x\sim p\_{\text{data}}}\log\frac{p\_\theta(x)}{p\_\theta(x)+q\_\alpha(x)} + \mathbb{E}\_{\tilde{x}\sim q\_\alpha}\log\frac{q\_\alpha(\tilde{x})}{p\_\theta(\tilde{x}) + q\_\alpha(\tilde{x})}.$$

This objective is implemented as the `value` function in [util.py](util.py).

Training alternates between the two models: while the classification accuracy is below a threshold, the energy model is trained; once it exceeds the threshold, the flow model is trained.

In the paper, the authors use [Glow](https://arxiv.org/abs/1807.03039) as the flow model. This repository implements both Glow and [MAF](https://arxiv.org/abs/1705.07057).

## Training

To train the model, run

```shell
python train.py
```

| Argument | Meaning |
| ---------------------- | -------------------------------------------------------------------- |
| `--seed=42` | random seed |
| `--epoch=100` | number of training epochs |
| `--threshold=0.6` | classification-accuracy threshold for alternating training |
| `--batch=1000` | batch size |
| `--flow=glow` | flow model to use: `glow` or `maf` |
| `--dataset=8gaussians` | Available datasets: `8gaussians`, `spiral`, `2spirals`, `checkerboard`, `rings`, `pinwheel` |
| `--samples=500000` | training set size |
| `--lr_ebm=1e-3` | Adam learning rate for the EBM |
| `--lr_flow=7e-4` | Adam learning rate for the flow model |
| `--b1=0.9` | Adam β₁ (first-moment decay) |
| `--b2=0.999` | Adam β₂ (second-moment decay) |

### Install wandb

To run the script you need to install [Weights & Biases (wandb)](https://wandb.ai/site). It is an experiment-tracking tool used here to monitor metrics during training. I find it easy and convenient to use, and I encourage you to give it a try as well.

First, sign up on their website https://wandb.ai/site (you can use your GitHub account) and copy your API key.

Next, install the Python package:

```shell
pip install wandb
```

When you start the experiment you may be asked to log in: paste your API key and hit enter. Once the experiment is running, you can go to https://app.wandb.ai/home and monitor it in real time.

If you do not wish to use wandb, you can comment out all the `wandb` calls in [train.py](train.py).

## Visualizations

### Density Plots

The figure below shows the result of Flow Contrastive Estimation using MAF as the flow model. The left column shows three data distributions. The middle column shows the densities learned by MAF; note that these are also the densities the energy model has to distinguish from the data. The right column shows the densities learned by the EBM.

![fce-maf](imgs/fce-maf.png)

For reference, here is the result presented in the FCE paper, showing the Glow and EBM densities learned on three data distributions.

![fce-glow](imgs/fce-glow.png)

### MSE

For the `8gaussians` dataset we have an analytical formula for the true data density, so we can evaluate the MSE between the log-density learned by the energy model and the true log-density. The plot below shows this MSE on the `8gaussians` training set.

![mse](imgs/mse.png)

### Value

The figure below shows the negative of the value function during training. If both the EBM $p\_\theta$ and the flow $q\_\alpha$ are close to the data distribution $p\_{\text{data}}$, then $p\_\theta\approx q\_\alpha\approx p\_{\text{data}}$, both classification posteriors approach $1/2$, and the value should be approximately

$$- V(\alpha,\theta)\approx -\left(\log\frac{1}{2}+\log\frac{1}{2}\right) = \log 4 \approx 1.39.$$

![value](imgs/value.png)

### Accuracy

In our experiments we use `0.6` as the default threshold, and as expected the classification accuracy of the EBM fluctuates around 0.6.

![acc](imgs/acc.png)

## Reference

Gao, Ruiqi, et al. "Flow Contrastive Estimation of Energy-Based Models." *Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition*. 2020.
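## Loading a Checkpoint

[train.py](train.py) saves the state dicts of both models to `ckpts/fce-<flow>-2d-<dataset>.pth.tar`. Below is a minimal sketch of reloading them and redrawing the density figure with `util.plot`; it assumes the default `--flow=glow --dataset=8gaussians` run, and note that `util.plot` is commented out during training and writes to an `images/` folder, which must exist.

```python
import os
import torch
import util
from ebm import EBM
from flows.glow import Glow

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# re-create the models with the same hyperparameters as in train.py
energy = EBM().to(device)
flow = Glow(width=64, depth=5, n_levels=1, data_dim=2).to(device)

state = torch.load('ckpts/fce-glow-2d-8gaussians.pth.tar', map_location=device)
energy.load_state_dict(state['energy'])
flow.load_state_dict(state['flow'])
print('loaded checkpoint from epoch', state['epoch'])

# redraw the 2x2 figure (data / flow density / flow samples / EBM density)
os.makedirs('images', exist_ok=True)
data = util.sample_2d_data('8gaussians', n_samples=100000)   # kept on CPU for util.plot
util.plot(data, energy, flow, epoch=state['epoch'], device=device)
```

For a `--flow=maf` run, construct `MAF(n_blocks=5, input_size=2, hidden_size=100, n_hidden=1)` from `flows.maf` instead of the Glow model.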
-------------------------------------------------------------------------------- /flows/maf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Masked Autoregressive Flow for Density Estimation 3 | arXiv:1705.07057v4 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.distributions as D 10 | 11 | # -------------------- 12 | # Model layers and helpers 13 | # -------------------- 14 | def create_masks(input_size, hidden_size, n_hidden, input_order='sequential', input_degrees=None): 15 | # MADE paper sec 4: 16 | # degrees of connections between layers -- ensure at most in_degree - 1 connections 17 | degrees = [] 18 | 19 | # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades); 20 | # else init input degrees based on strategy in input_order (sequential or random) 21 | if input_order == 'sequential': 22 | degrees += [torch.arange(input_size)] if input_degrees is None else [input_degrees] 23 | for _ in range(n_hidden + 1): 24 | degrees += [torch.arange(hidden_size) % (input_size - 1)] 25 | degrees += [torch.arange(input_size) % input_size - 1] if input_degrees is None else [input_degrees % input_size - 1] 26 | elif input_order == 'random': 27 | degrees += [torch.randperm(input_size)] if input_degrees is None else [input_degrees] 28 | for _ in range(n_hidden + 1): 29 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1) 30 | degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))] 31 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1) 32 | degrees += [torch.randint(min_prev_degree, input_size, (input_size,)) - 1] if input_degrees is None else [input_degrees - 1] 33 | 34 | # construct masks 35 | masks = [] 36 | for (d0, d1) in zip(degrees[:-1], degrees[1:]): 37 | masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()] 38 | 39 | return masks, degrees[0] 40 | 41 | 42 | 43 | class MaskedLinear(nn.Linear): 44 | """ MADE building block layer """ 45 | def __init__(self, input_size, n_outputs, mask, cond_label_size=None): 46 | super().__init__(input_size, n_outputs) 47 | 48 | self.register_buffer('mask', mask) 49 | 50 | self.cond_label_size = cond_label_size 51 | if cond_label_size is not None: 52 | self.cond_weight = nn.Parameter(torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)) 53 | 54 | def forward(self, x, y=None): 55 | out = F.linear(x, self.weight * self.mask, self.bias) 56 | if y is not None: 57 | out = out + F.linear(y, self.cond_weight) 58 | return out 59 | 60 | def extra_repr(self): 61 | return 'in_features={}, out_features={}, bias={}'.format( 62 | self.in_features, self.out_features, self.bias is not None 63 | ) + (self.cond_label_size != None) * ', cond_features={}'.format(self.cond_label_size) 64 | 65 | 66 | 67 | class FlowSequential(nn.Sequential): 68 | """ Container for layers of a normalizing flow """ 69 | def forward(self, x, y): 70 | sum_log_abs_det_jacobians = 0 71 | for module in self: 72 | x, log_abs_det_jacobian = module(x, y) 73 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian 74 | return x, sum_log_abs_det_jacobians 75 | 76 | def inverse(self, u, y): 77 | sum_log_abs_det_jacobians = 0 78 | for module in reversed(self): 79 | u, log_abs_det_jacobian = module.inverse(u, y) 80 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian 81 | return u, sum_log_abs_det_jacobians 82 | 83 | 84 | class 
BatchNorm(nn.Module): 85 | """ RealNVP BatchNorm layer """ 86 | def __init__(self, input_size, momentum=0.9, eps=1e-5): 87 | super().__init__() 88 | self.momentum = momentum 89 | self.eps = eps 90 | 91 | self.log_gamma = nn.Parameter(torch.zeros(input_size)) 92 | self.beta = nn.Parameter(torch.zeros(input_size)) 93 | 94 | self.register_buffer('running_mean', torch.zeros(input_size)) 95 | self.register_buffer('running_var', torch.ones(input_size)) 96 | 97 | def forward(self, x, cond_y=None): 98 | if self.training: 99 | self.batch_mean = x.mean(0) 100 | self.batch_var = x.var(0) # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False) 101 | 102 | # update running mean 103 | self.running_mean.mul_(self.momentum).add_(self.batch_mean.data * (1 - self.momentum)) 104 | self.running_var.mul_(self.momentum).add_(self.batch_var.data * (1 - self.momentum)) 105 | 106 | mean = self.batch_mean 107 | var = self.batch_var 108 | else: 109 | mean = self.running_mean 110 | var = self.running_var 111 | 112 | # compute normalized input (cf original batch norm paper algo 1) 113 | x_hat = (x - mean) / torch.sqrt(var + self.eps) 114 | y = self.log_gamma.exp() * x_hat + self.beta 115 | 116 | # compute log_abs_det_jacobian (cf RealNVP paper) 117 | log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps) 118 | # print('in sum log var {:6.3f} ; out sum log var {:6.3f}; sum log det {:8.3f}; mean log_gamma {:5.3f}; mean beta {:5.3f}'.format( 119 | # (var + self.eps).log().sum().data.numpy(), y.var(0).log().sum().data.numpy(), log_abs_det_jacobian.mean(0).item(), self.log_gamma.mean(), self.beta.mean())) 120 | return y, log_abs_det_jacobian.expand_as(x) 121 | 122 | def inverse(self, y, cond_y=None): 123 | if self.training: 124 | mean = self.batch_mean 125 | var = self.batch_var 126 | else: 127 | mean = self.running_mean 128 | var = self.running_var 129 | 130 | x_hat = (y - self.beta) * torch.exp(-self.log_gamma) 131 | x = x_hat * torch.sqrt(var + self.eps) + mean 132 | 133 | log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma 134 | 135 | return x, log_abs_det_jacobian.expand_as(x) 136 | 137 | 138 | # -------------------- 139 | # Models 140 | # -------------------- 141 | class MADE(nn.Module): 142 | def __init__(self, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', input_degrees=None): 143 | """ 144 | Args: 145 | input_size -- scalar; dim of inputs 146 | hidden_size -- scalar; dim of hidden layers 147 | n_hidden -- scalar; number of hidden layers 148 | activation -- str; activation function to use 149 | input_order -- str or tensor; variable order for creating the autoregressive masks (sequential|random) 150 | or the order flipped from the previous layer in a stack of mades 151 | conditional -- bool; whether model is conditional 152 | """ 153 | super().__init__() 154 | # base distribution for calculation of log prob under the model 155 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 156 | self.register_buffer('base_dist_var', torch.ones(input_size)) 157 | 158 | # create masks 159 | masks, self.input_degrees = create_masks(input_size, hidden_size, n_hidden, input_order, input_degrees) 160 | 161 | # setup activation 162 | if activation == 'relu': 163 | activation_fn = nn.ReLU() 164 | elif activation == 'tanh': 165 | activation_fn = nn.Tanh() 166 | else: 167 | raise ValueError('Check activation function.') 168 | 169 | # construct model 170 | self.net_input = MaskedLinear(input_size, hidden_size, 
masks[0], cond_label_size) 171 | self.net = [] 172 | for m in masks[1:-1]: 173 | self.net += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)] 174 | self.net += [activation_fn, MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2,1))] 175 | self.net = nn.Sequential(*self.net) 176 | 177 | @property 178 | def base_dist(self): 179 | return D.Normal(self.base_dist_mean, self.base_dist_var) 180 | 181 | def forward(self, x, y=None): 182 | # MAF eq 4 -- return mean and log std 183 | mu, alpha = self.net(self.net_input(x, y)).chunk(chunks=2, dim=1) 184 | u = (x - mu) * torch.exp(-alpha) 185 | # MAF eq 5 186 | logdet = - alpha 187 | return u, logdet 188 | 189 | def inverse(self, u, y=None, sum_log_abs_det_jacobians=None): 190 | # MAF eq 3 191 | # D = u.shape[1] 192 | x = torch.zeros_like(u) 193 | # run through reverse model 194 | for i in self.input_degrees: 195 | mu, alpha = self.net(self.net_input(x.clone(), y)).chunk(chunks=2, dim=1) 196 | x[:,i] = mu[:,i] + u[:,i] * torch.exp(alpha[:,i]) 197 | logdet = alpha 198 | return x, logdet 199 | 200 | def log_prob(self, x, y=None): 201 | u, logdet = self.forward(x, y) 202 | return torch.sum(self.base_dist.log_prob(u) + logdet, dim=1, keepdim=True) 203 | 204 | 205 | 206 | class MAF(nn.Module): 207 | def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', batch_norm=True): 208 | super().__init__() 209 | # base distribution for calculation of log prob under the model 210 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 211 | self.register_buffer('base_dist_var', torch.ones(input_size)) 212 | 213 | 214 | # construct model 215 | modules = [] 216 | self.input_degrees = None 217 | for i in range(n_blocks): 218 | modules += [MADE(input_size, hidden_size, n_hidden, cond_label_size, activation, input_order, self.input_degrees)] 219 | self.input_degrees = modules[-1].input_degrees.flip(0) 220 | # modules += batch_norm * [BatchNorm(input_size)] 221 | 222 | self.net = FlowSequential(*modules) 223 | 224 | @property 225 | def base_dist(self): 226 | return D.Normal(self.base_dist_mean, self.base_dist_var) 227 | 228 | def forward(self, x, y=None): 229 | return self.net(x, y) 230 | 231 | def inverse(self, u, y=None): 232 | gen, logdet = self.net.inverse(u, y) 233 | 234 | # get logq(x̃), where x̃ = f^{-1}(z) 235 | logq_gen = torch.sum(self.base_dist.log_prob(u) - logdet, dim=1, keepdim=True) 236 | 237 | return gen, logq_gen 238 | 239 | 240 | def log_prob(self, x, y=None): 241 | u, logdet = self.forward(x, y) 242 | return torch.sum(self.base_dist.log_prob(u) + logdet, dim=1, keepdim=True) -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import torch.distributions as D 6 | import torch.nn.functional as F 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | font = {'family' : 'sans-serif', 12 | 'weight' : 'bold', 13 | 'size' : 16} 14 | matplotlib.rc('font', **font) 15 | 16 | # ------------------------------ 17 | # FCE Value Function 18 | # ------------------------------ 19 | def value(energy, flow, x, z, maximize=True): 20 | gen, logq_gen = flow.inverse(z) # the second term is logq(x̃) 21 | logp_x = energy(x) # logp(x) 22 | logp_gen = energy(gen) #logp(x̃) 23 | logq_x = flow.log_prob(x) # logq(x) 24 | 25 
| value_x = logp_x - torch.logsumexp(torch.cat([logp_x, logq_x], dim=1), dim=1, keepdim=True) # logp(x)/(logp(x) + logq(x)) 26 | value_gen = logq_gen - torch.logsumexp(torch.cat([logp_gen, logq_gen], dim=1), dim=1, keepdim=True) # logq(x̃)/(logp(x̃) + logq(x̃)) 27 | 28 | v = value_x.mean() + value_gen.mean() 29 | 30 | # calculate accuracy 31 | r_x = torch.sigmoid(logp_x - logq_x) 32 | r_gen = torch.sigmoid(logq_gen - logp_gen) 33 | acc = ((r_x >= 1/2).sum() + (r_gen > 1/2).sum()).cpu().numpy() / (len(r_x) + len(r_gen)) 34 | 35 | if maximize: 36 | return -v, acc 37 | else: 38 | return v, acc 39 | 40 | # ------------------------------ 41 | # MIXED GAUSSIAN DENSITY 42 | # ------------------------------ 43 | class MixedGaussian(): 44 | def __init__(self, device): 45 | r = 2. * math.sqrt(2) 46 | self.covariance_matrix = (1/8 * torch.eye(2)).to(device) # variance on the diagonal, not standard deviation 47 | self.m1 = D.MultivariateNormal(loc=torch.tensor([r, 0.], device=device), covariance_matrix=self.covariance_matrix) 48 | self.m2 = D.MultivariateNormal(loc=torch.tensor([-r, 0.], device=device), covariance_matrix=self.covariance_matrix) 49 | self.m3 = D.MultivariateNormal(loc=torch.tensor([0., r], device=device), covariance_matrix=self.covariance_matrix) 50 | self.m4 = D.MultivariateNormal(loc=torch.tensor([0., -r], device=device), covariance_matrix=self.covariance_matrix) 51 | self.m5 = D.MultivariateNormal(loc=torch.tensor([2., 2.], device=device), covariance_matrix=self.covariance_matrix) 52 | self.m6 = D.MultivariateNormal(loc=torch.tensor([-2., 2.], device=device), covariance_matrix=self.covariance_matrix) 53 | self.m7 = D.MultivariateNormal(loc=torch.tensor([2., -2.], device=device), covariance_matrix=self.covariance_matrix) 54 | self.m8 = D.MultivariateNormal(loc=torch.tensor([-2., -2.], device=device), covariance_matrix=self.covariance_matrix) 55 | 56 | def log_prob(self, x): 57 | log_probs = torch.cat([self.m1.log_prob(x).unsqueeze(1), 58 | self.m2.log_prob(x).unsqueeze(1), 59 | self.m3.log_prob(x).unsqueeze(1), 60 | self.m4.log_prob(x).unsqueeze(1), 61 | self.m5.log_prob(x).unsqueeze(1), 62 | self.m6.log_prob(x).unsqueeze(1), 63 | self.m7.log_prob(x).unsqueeze(1), 64 | self.m8.log_prob(x).unsqueeze(1)], dim=1) 65 | out = - math.log(8.) + torch.logsumexp(log_probs, dim=1, keepdim=True) 66 | return out 67 | 68 | # MSE between EBM and a true distribution. 
69 | def mse(ebm, true_dist, data_batch): 70 | with torch.no_grad(): 71 | ebm_outputs = ebm(data_batch) 72 | true_values = true_dist.log_prob(data_batch) 73 | return F.mse_loss(input=ebm_outputs, target=true_values).item() 74 | 75 | #------------------------------------------- 76 | # DATA 77 | #------------------------------------------- 78 | def get_data(args): 79 | dataset = sample_2d_data(dataset=args.dataset, n_samples=args.samples) 80 | dataloader = DataLoader(dataset, batch_size=args.batch, shuffle=True) 81 | return dataset, dataloader 82 | 83 | 84 | def sample_2d_data(dataset='8gaussians', n_samples=50000): 85 | 86 | z = torch.randn(n_samples, 2) 87 | 88 | if dataset == '8gaussians': 89 | scale = 4 90 | sq2 = 1/math.sqrt(2) 91 | centers = [(1,0), (-1,0), (0,1), (0,-1), (sq2,sq2), (-sq2,sq2), (sq2,-sq2), (-sq2,-sq2)] 92 | centers = torch.tensor([(scale * x, scale * y) for x,y in centers]) 93 | return sq2 * (0.5 * z + centers[torch.randint(len(centers), size=(n_samples,))]) 94 | # return 0.05 * z + centers[torch.randint(len(centers), size=(n_samples,))] 95 | 96 | elif dataset == '2spirals': 97 | n = torch.sqrt(torch.rand(n_samples // 2)) * 540 * (2 * math.pi) / 360 98 | d1x = - torch.cos(n) * n + torch.rand(n_samples // 2) * 0.5 99 | d1y = torch.sin(n) * n + torch.rand(n_samples // 2) * 0.5 100 | x = torch.cat([torch.stack([ d1x, d1y], dim=1), 101 | torch.stack([-d1x, -d1y], dim=1)], dim=0) / 3 102 | return x + 0.1*z 103 | 104 | elif dataset == 'spiral': 105 | n = torch.sqrt(torch.rand(n_samples)) * 540 * (2 * math.pi) / 360 106 | d1x = - torch.cos(n) * n + torch.rand(n_samples) * 0.5 107 | d1y = torch.sin(n) * n + torch.rand(n_samples) * 0.5 108 | print(d1x.shape) 109 | x = torch.stack([ d1x, d1y], dim=1) / 3 110 | print(x.shape) 111 | return x + 0.1*z 112 | 113 | elif dataset == 'checkerboard': 114 | x1 = torch.rand(n_samples) * 4 - 2 115 | x2_ = torch.rand(n_samples) - torch.randint(0, 2, (n_samples,), dtype=torch.float) * 2 116 | x2 = x2_ + x1.floor() % 2 117 | return torch.stack([x1, x2], dim=1) * 2 118 | 119 | elif dataset == 'rings': 120 | n_samples4 = n_samples3 = n_samples2 = n_samples // 4 121 | n_samples1 = n_samples - n_samples4 - n_samples3 - n_samples2 122 | # so as not to have the first point = last point, set endpoint=False in np; here shifted by one 123 | linspace4 = torch.linspace(0, 2 * math.pi, n_samples4 + 1)[:-1] 124 | linspace3 = torch.linspace(0, 2 * math.pi, n_samples3 + 1)[:-1] 125 | linspace2 = torch.linspace(0, 2 * math.pi, n_samples2 + 1)[:-1] 126 | linspace1 = torch.linspace(0, 2 * math.pi, n_samples1 + 1)[:-1] 127 | circ4_x = torch.cos(linspace4) 128 | circ4_y = torch.sin(linspace4) 129 | circ3_x = torch.cos(linspace4) * 0.75 130 | circ3_y = torch.sin(linspace3) * 0.75 131 | circ2_x = torch.cos(linspace2) * 0.5 132 | circ2_y = torch.sin(linspace2) * 0.5 133 | circ1_x = torch.cos(linspace1) * 0.25 134 | circ1_y = torch.sin(linspace1) * 0.25 135 | 136 | x = torch.stack([torch.cat([circ4_x, circ3_x, circ2_x, circ1_x]), 137 | torch.cat([circ4_y, circ3_y, circ2_y, circ1_y])], dim=1) * 3.0 138 | 139 | # random sample 140 | x = x[torch.randint(0, n_samples, size=(n_samples,))] 141 | 142 | # Add noise 143 | return x + torch.normal(mean=torch.zeros_like(x), std=0.08*torch.ones_like(x)) 144 | 145 | elif dataset == "pinwheel": 146 | rng = np.random.RandomState() 147 | radial_std = 0.3 148 | tangential_std = 0.1 149 | num_classes = 5 150 | num_per_class = n_samples // 5 151 | rate = 0.25 152 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 153 | 
features = rng.randn(num_classes*num_per_class, 2) * np.array([radial_std, tangential_std]) 154 | # features = np.random.randn(num_classes*num_per_class, 2) * np.array([radial_std, tangential_std]) 155 | features[:, 0] += 1. 156 | labels = np.repeat(np.arange(num_classes), num_per_class) 157 | 158 | angles = rads[labels] + rate * np.exp(features[:, 0]) 159 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 160 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 161 | 162 | data = 2 * rng.permutation(np.einsum("ti,tij->tj", features, rotations)) 163 | # data = 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations)) 164 | return torch.as_tensor(data, dtype=torch.float32) 165 | 166 | else: 167 | raise RuntimeError('Invalid `dataset` to sample from.') 168 | 169 | # -------------------- 170 | # Plotting 171 | # -------------------- 172 | 173 | @torch.no_grad() 174 | def plot(dataset, energy, flow, epoch, device): 175 | n_pts = 1000 176 | range_lim = 4 177 | # construct test points 178 | test_grid = setup_grid(range_lim, n_pts, device) 179 | 180 | # plot 181 | fig, axs = plt.subplots(2, 2, figsize=(8,8), subplot_kw={'aspect': 'equal'}) 182 | plot_samples(dataset, axs[0][0], range_lim, n_pts) 183 | plot_flow(flow, axs[0][1], test_grid, n_pts) 184 | plot_flow_samples(flow, axs[1][0], range_lim, n_pts) 185 | plot_energy(energy, axs[1][1], test_grid, n_pts) 186 | 187 | # format 188 | for ax in plt.gcf().axes: format_ax(ax, range_lim) 189 | plt.tight_layout(pad=2.0) 190 | 191 | # save 192 | print('Saving image to images/....') 193 | plt.savefig('images/epoch_{}.png'.format(epoch)) 194 | plt.close() 195 | 196 | 197 | def setup_grid(range_lim, n_pts, device): 198 | x = torch.linspace(-range_lim, range_lim, n_pts) 199 | xx, yy = torch.meshgrid((x, x)) 200 | zz = torch.stack((xx.flatten(), yy.flatten()), dim=1) 201 | return xx, yy, zz.to(device) 202 | 203 | def plot_samples(dataset, ax, range_lim, n_pts): 204 | samples = dataset.numpy() 205 | ax.hist2d(samples[:,0], samples[:,1], range=[[-range_lim, range_lim], [-range_lim, range_lim]], bins=n_pts, cmap=plt.cm.jet) 206 | ax.set_title('Data') 207 | 208 | def plot_energy(energy, ax, test_grid, n_pts): 209 | xx, yy, zz = test_grid 210 | log_prob = energy(zz) 211 | prob = log_prob.exp().cpu() 212 | ax.pcolormesh(xx, yy, prob.view(n_pts,n_pts), cmap=plt.cm.jet) 213 | ax.set_facecolor(plt.cm.jet(0.)) 214 | ax.set_title('EBM') 215 | 216 | def plot_flow(flow, ax, test_grid, n_pts): 217 | flow.eval() 218 | xx, yy, zz = test_grid 219 | log_prob = flow.log_prob(zz) 220 | prob = log_prob.exp().cpu() 221 | ax.pcolormesh(xx, yy, prob.view(n_pts,n_pts), cmap=plt.cm.jet) 222 | ax.set_facecolor(plt.cm.jet(0.)) 223 | ax.set_title('Flow') 224 | 225 | def plot_flow_samples(flow, ax, range_lim, n_pts): 226 | z = flow.base_dist.sample((10000,)) 227 | samples, _ = flow.inverse(z) 228 | samples = samples.cpu().numpy() 229 | ax.hist2d(samples[:,0], samples[:,1], range=[[-range_lim, range_lim], [-range_lim, range_lim]], bins=n_pts, cmap=plt.cm.jet) 230 | ax.set_title('Flow-samples') 231 | 232 | def format_ax(ax, range_lim): 233 | ax.set_xlim(-range_lim, range_lim) 234 | ax.set_ylim(-range_lim, range_lim) 235 | ax.get_xaxis().set_visible(False) 236 | ax.get_yaxis().set_visible(False) 237 | ax.invert_yaxis() --------------------------------------------------------------------------------